Merge pull request #1637 from touwaeriol/feat/websearch-notify-pricing

feat: web search emulation, balance/quota notify, account stats pricing, per-provider refund control, Stripe fix / Web 搜索模拟、余额配额通知、渠道统计计费、按服务商退款控制、Stripe 修复
This commit is contained in:
Wesley Liddick 2026-04-14 20:41:09 +08:00 committed by GitHub
commit d402e722cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
177 changed files with 13640 additions and 1198 deletions

View File

@ -17,6 +17,7 @@ jobs:
go-version-file: backend/go.mod go-version-file: backend/go.mod
check-latest: false check-latest: false
cache: true cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version - name: Verify Go version
run: | run: |
go version | grep -q 'go1.26.2' go version | grep -q 'go1.26.2'
@ -36,6 +37,7 @@ jobs:
go-version-file: backend/go.mod go-version-file: backend/go.mod
check-latest: false check-latest: false
cache: true cache: true
cache-dependency-path: backend/go.sum
- name: Verify Go version - name: Verify Go version
run: | run: |
go version | grep -q 'go1.26.2' go version | grep -q 'go1.26.2'

View File

@ -36,19 +36,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
// Business layer ProviderSets // Business layer ProviderSets
repository.ProviderSet, repository.ProviderSet,
service.ProviderSet, service.ProviderSet,
payment.ProviderSet,
middleware.ProviderSet, middleware.ProviderSet,
handler.ProviderSet, handler.ProviderSet,
// Server layer ProviderSet // Server layer ProviderSet
server.ProviderSet, server.ProviderSet,
// Payment providers
payment.ProvideRegistry,
payment.ProvideEncryptionKey,
payment.ProvideDefaultLoadBalancer,
service.ProvidePaymentConfigService,
service.ProvidePaymentOrderExpiryService,
// Privacy client factory for OpenAI training opt-out // Privacy client factory for OpenAI training opt-out
providePrivacyClientFactory, providePrivacyClientFactory,

View File

@ -50,8 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
refreshTokenCache := repository.NewRefreshTokenCache(redisClient) refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client) settingRepository := repository.NewSettingRepository(client)
groupRepository := repository.NewGroupRepository(client, db) groupRepository := repository.NewGroupRepository(client, db)
channelRepository := repository.NewChannelRepository(db) proxyRepository := repository.NewProxyRepository(client, db)
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient) emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache) emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier() turnstileVerifier := repository.NewTurnstileVerifier()
@ -65,23 +65,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userGroupRateRepository := repository.NewUserGroupRateRepository(db) userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
redeemCache := repository.NewRedeemCache(redisClient) redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
registry := payment.ProvideRegistry()
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
if err != nil {
return nil, err
}
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
secretEncryptor, err := repository.NewAESEncryptor(configConfig) secretEncryptor, err := repository.NewAESEncryptor(configConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@ -89,10 +79,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
totpCache := repository.NewTotpCache(redisClient) totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userHandler := handler.NewUserHandler(userService) userHandler := handler.NewUserHandler(userService, emailService, emailCache)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db) usageLogRepository := repository.NewUsageLogRepository(client, db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemHandler := handler.NewRedeemHandler(redeemService) redeemHandler := handler.NewRedeemHandler(redeemService)
@ -112,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig) schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
accountRepository := repository.NewAccountRepository(client, db, schedulerCache) accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory() privacyClientFactory := providePrivacyClientFactory()
@ -120,11 +108,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient() claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient() openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
openAIOAuthService.SetPrivacyClientFactory(privacyClientFactory)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient() driveClient := repository.NewGeminiDriveClient()
@ -134,7 +125,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tempUnschedCache := repository.NewTempUnschedCache(redisClient) tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
oauthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig) httpUpstream := repository.NewHTTPUpstream(configConfig)
@ -142,23 +132,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache() usageCache := service.NewUsageCache()
identityCache := repository.NewIdentityCache(redisClient) identityCache := repository.NewIdentityCache(redisClient)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oauthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client) tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService() dataManagementService := service.NewDataManagementService()
@ -175,6 +162,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService) promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db) opsRepository := repository.NewOpsRepository(db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil { if err != nil {
@ -183,17 +171,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
billingService := service.NewBillingService(configConfig, pricingService) billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache) identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oauthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
digestSessionStore := service.NewDigestSessionStore() digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
opsHandler := admin.NewOpsHandler(opsService) opsHandler := admin.NewOpsHandler(opsService)
updateCache := repository.NewUpdateCache(redisClient) updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
@ -221,8 +210,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService) channelHandler := admin.NewChannelHandler(channelService, billingService)
adminPaymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) registry := payment.ProvideRegistry()
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, adminPaymentHandler) encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
if err != nil {
return nil, err
}
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
@ -245,7 +244,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oauthRefreshAPI) tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository) accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)

View File

@ -616,6 +616,7 @@ var (
{Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "sort_order", Type: field.TypeInt, Default: 0},
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
{Name: "refund_enabled", Type: field.TypeBool, Default: false}, {Name: "refund_enabled", Type: field.TypeBool, Default: false},
{Name: "allow_user_refund", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
} }
@ -1078,6 +1079,11 @@ var (
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled", Type: field.TypeBool, Default: false},
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
} }
// UsersTable holds the schema information for the "users" table. // UsersTable holds the schema information for the "users" table.
UsersTable = &schema.Table{ UsersTable = &schema.Table{

View File

@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error {
// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. // PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
type PaymentProviderInstanceMutation struct { type PaymentProviderInstanceMutation struct {
config config
op Op op Op
typ string typ string
id *int64 id *int64
provider_key *string provider_key *string
name *string name *string
_config *string _config *string
supported_types *string supported_types *string
enabled *bool enabled *bool
payment_mode *string payment_mode *string
sort_order *int sort_order *int
addsort_order *int addsort_order *int
limits *string limits *string
refund_enabled *bool refund_enabled *bool
created_at *time.Time allow_user_refund *bool
updated_at *time.Time created_at *time.Time
clearedFields map[string]struct{} updated_at *time.Time
done bool clearedFields map[string]struct{}
oldValue func(context.Context) (*PaymentProviderInstance, error) done bool
predicates []predicate.PaymentProviderInstance oldValue func(context.Context) (*PaymentProviderInstance, error)
predicates []predicate.PaymentProviderInstance
} }
var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
m.refund_enabled = nil m.refund_enabled = nil
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
m.allow_user_refund = &b
}
// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
v := m.allow_user_refund
if v == nil {
return
}
return *v, true
}
// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
}
return oldValue.AllowUserRefund, nil
}
// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
m.allow_user_refund = nil
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
m.created_at = &t m.created_at = &t
@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *PaymentProviderInstanceMutation) Fields() []string { func (m *PaymentProviderInstanceMutation) Fields() []string {
fields := make([]string, 0, 11) fields := make([]string, 0, 12)
if m.provider_key != nil { if m.provider_key != nil {
fields = append(fields, paymentproviderinstance.FieldProviderKey) fields = append(fields, paymentproviderinstance.FieldProviderKey)
} }
@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
if m.refund_enabled != nil { if m.refund_enabled != nil {
fields = append(fields, paymentproviderinstance.FieldRefundEnabled) fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
} }
if m.allow_user_refund != nil {
fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
}
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, paymentproviderinstance.FieldCreatedAt) fields = append(fields, paymentproviderinstance.FieldCreatedAt)
} }
@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
return m.Limits() return m.Limits()
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
return m.RefundEnabled() return m.RefundEnabled()
case paymentproviderinstance.FieldAllowUserRefund:
return m.AllowUserRefund()
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
return m.CreatedAt() return m.CreatedAt()
case paymentproviderinstance.FieldUpdatedAt: case paymentproviderinstance.FieldUpdatedAt:
@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str
return m.OldLimits(ctx) return m.OldLimits(ctx)
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
return m.OldRefundEnabled(ctx) return m.OldRefundEnabled(ctx)
case paymentproviderinstance.FieldAllowUserRefund:
return m.OldAllowUserRefund(ctx)
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
return m.OldCreatedAt(ctx) return m.OldCreatedAt(ctx)
case paymentproviderinstance.FieldUpdatedAt: case paymentproviderinstance.FieldUpdatedAt:
@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value)
} }
m.SetRefundEnabled(v) m.SetRefundEnabled(v)
return nil return nil
case paymentproviderinstance.FieldAllowUserRefund:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAllowUserRefund(v)
return nil
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
v, ok := value.(time.Time) v, ok := value.(time.Time)
if !ok { if !ok {
@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
case paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldRefundEnabled:
m.ResetRefundEnabled() m.ResetRefundEnabled()
return nil return nil
case paymentproviderinstance.FieldAllowUserRefund:
m.ResetAllowUserRefund()
return nil
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
m.ResetCreatedAt() m.ResetCreatedAt()
return nil return nil
@ -28210,6 +28264,13 @@ type UserMutation struct {
totp_secret_encrypted *string totp_secret_encrypted *string
totp_enabled *bool totp_enabled *bool
totp_enabled_at *time.Time totp_enabled_at *time.Time
balance_notify_enabled *bool
balance_notify_threshold_type *string
balance_notify_threshold *float64
addbalance_notify_threshold *float64
balance_notify_extra_emails *string
total_recharged *float64
addtotal_recharged *float64
clearedFields map[string]struct{} clearedFields map[string]struct{}
api_keys map[int64]struct{} api_keys map[int64]struct{}
removedapi_keys map[int64]struct{} removedapi_keys map[int64]struct{}
@ -28927,6 +28988,240 @@ func (m *UserMutation) ResetTotpEnabledAt() {
delete(m.clearedFields, user.FieldTotpEnabledAt) delete(m.clearedFields, user.FieldTotpEnabledAt)
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
m.balance_notify_enabled = &b
}
// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
v := m.balance_notify_enabled
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
}
return oldValue.BalanceNotifyEnabled, nil
}
// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
func (m *UserMutation) ResetBalanceNotifyEnabled() {
m.balance_notify_enabled = nil
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
m.balance_notify_threshold_type = &s
}
// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
v := m.balance_notify_threshold_type
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
}
return oldValue.BalanceNotifyThresholdType, nil
}
// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
func (m *UserMutation) ResetBalanceNotifyThresholdType() {
m.balance_notify_threshold_type = nil
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
m.balance_notify_threshold = &f
m.addbalance_notify_threshold = nil
}
// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
v := m.balance_notify_threshold
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
}
return oldValue.BalanceNotifyThreshold, nil
}
// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
if m.addbalance_notify_threshold != nil {
*m.addbalance_notify_threshold += f
} else {
m.addbalance_notify_threshold = &f
}
}
// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
v := m.addbalance_notify_threshold
if v == nil {
return
}
return *v, true
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (m *UserMutation) ClearBalanceNotifyThreshold() {
m.balance_notify_threshold = nil
m.addbalance_notify_threshold = nil
m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
}
// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
_, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
return ok
}
// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
func (m *UserMutation) ResetBalanceNotifyThreshold() {
m.balance_notify_threshold = nil
m.addbalance_notify_threshold = nil
delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
m.balance_notify_extra_emails = &s
}
// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
v := m.balance_notify_extra_emails
if v == nil {
return
}
return *v, true
}
// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
}
return oldValue.BalanceNotifyExtraEmails, nil
}
// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
m.balance_notify_extra_emails = nil
}
// SetTotalRecharged sets the "total_recharged" field.
func (m *UserMutation) SetTotalRecharged(f float64) {
m.total_recharged = &f
m.addtotal_recharged = nil
}
// TotalRecharged returns the value of the "total_recharged" field in the mutation.
func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
v := m.total_recharged
if v == nil {
return
}
return *v, true
}
// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
// If the User object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
}
return oldValue.TotalRecharged, nil
}
// AddTotalRecharged adds f to the "total_recharged" field.
func (m *UserMutation) AddTotalRecharged(f float64) {
if m.addtotal_recharged != nil {
*m.addtotal_recharged += f
} else {
m.addtotal_recharged = &f
}
}
// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
v := m.addtotal_recharged
if v == nil {
return
}
return *v, true
}
// ResetTotalRecharged resets all changes to the "total_recharged" field.
func (m *UserMutation) ResetTotalRecharged() {
m.total_recharged = nil
m.addtotal_recharged = nil
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil { if m.api_keys == nil {
@ -29501,7 +29796,7 @@ func (m *UserMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *UserMutation) Fields() []string { func (m *UserMutation) Fields() []string {
fields := make([]string, 0, 14) fields := make([]string, 0, 19)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, user.FieldCreatedAt) fields = append(fields, user.FieldCreatedAt)
} }
@ -29544,6 +29839,21 @@ func (m *UserMutation) Fields() []string {
if m.totp_enabled_at != nil { if m.totp_enabled_at != nil {
fields = append(fields, user.FieldTotpEnabledAt) fields = append(fields, user.FieldTotpEnabledAt)
} }
if m.balance_notify_enabled != nil {
fields = append(fields, user.FieldBalanceNotifyEnabled)
}
if m.balance_notify_threshold_type != nil {
fields = append(fields, user.FieldBalanceNotifyThresholdType)
}
if m.balance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
if m.balance_notify_extra_emails != nil {
fields = append(fields, user.FieldBalanceNotifyExtraEmails)
}
if m.total_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
return fields return fields
} }
@ -29580,6 +29890,16 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
return m.TotpEnabled() return m.TotpEnabled()
case user.FieldTotpEnabledAt: case user.FieldTotpEnabledAt:
return m.TotpEnabledAt() return m.TotpEnabledAt()
case user.FieldBalanceNotifyEnabled:
return m.BalanceNotifyEnabled()
case user.FieldBalanceNotifyThresholdType:
return m.BalanceNotifyThresholdType()
case user.FieldBalanceNotifyThreshold:
return m.BalanceNotifyThreshold()
case user.FieldBalanceNotifyExtraEmails:
return m.BalanceNotifyExtraEmails()
case user.FieldTotalRecharged:
return m.TotalRecharged()
} }
return nil, false return nil, false
} }
@ -29617,6 +29937,16 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
return m.OldTotpEnabled(ctx) return m.OldTotpEnabled(ctx)
case user.FieldTotpEnabledAt: case user.FieldTotpEnabledAt:
return m.OldTotpEnabledAt(ctx) return m.OldTotpEnabledAt(ctx)
case user.FieldBalanceNotifyEnabled:
return m.OldBalanceNotifyEnabled(ctx)
case user.FieldBalanceNotifyThresholdType:
return m.OldBalanceNotifyThresholdType(ctx)
case user.FieldBalanceNotifyThreshold:
return m.OldBalanceNotifyThreshold(ctx)
case user.FieldBalanceNotifyExtraEmails:
return m.OldBalanceNotifyExtraEmails(ctx)
case user.FieldTotalRecharged:
return m.OldTotalRecharged(ctx)
} }
return nil, fmt.Errorf("unknown User field %s", name) return nil, fmt.Errorf("unknown User field %s", name)
} }
@ -29724,6 +30054,41 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
} }
m.SetTotpEnabledAt(v) m.SetTotpEnabledAt(v)
return nil return nil
case user.FieldBalanceNotifyEnabled:
v, ok := value.(bool)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyEnabled(v)
return nil
case user.FieldBalanceNotifyThresholdType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyThresholdType(v)
return nil
case user.FieldBalanceNotifyThreshold:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyThreshold(v)
return nil
case user.FieldBalanceNotifyExtraEmails:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetBalanceNotifyExtraEmails(v)
return nil
case user.FieldTotalRecharged:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetTotalRecharged(v)
return nil
} }
return fmt.Errorf("unknown User field %s", name) return fmt.Errorf("unknown User field %s", name)
} }
@ -29738,6 +30103,12 @@ func (m *UserMutation) AddedFields() []string {
if m.addconcurrency != nil { if m.addconcurrency != nil {
fields = append(fields, user.FieldConcurrency) fields = append(fields, user.FieldConcurrency)
} }
if m.addbalance_notify_threshold != nil {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
if m.addtotal_recharged != nil {
fields = append(fields, user.FieldTotalRecharged)
}
return fields return fields
} }
@ -29750,6 +30121,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedBalance() return m.AddedBalance()
case user.FieldConcurrency: case user.FieldConcurrency:
return m.AddedConcurrency() return m.AddedConcurrency()
case user.FieldBalanceNotifyThreshold:
return m.AddedBalanceNotifyThreshold()
case user.FieldTotalRecharged:
return m.AddedTotalRecharged()
} }
return nil, false return nil, false
} }
@ -29773,6 +30148,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
} }
m.AddConcurrency(v) m.AddConcurrency(v)
return nil return nil
case user.FieldBalanceNotifyThreshold:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddBalanceNotifyThreshold(v)
return nil
case user.FieldTotalRecharged:
v, ok := value.(float64)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.AddTotalRecharged(v)
return nil
} }
return fmt.Errorf("unknown User numeric field %s", name) return fmt.Errorf("unknown User numeric field %s", name)
} }
@ -29790,6 +30179,9 @@ func (m *UserMutation) ClearedFields() []string {
if m.FieldCleared(user.FieldTotpEnabledAt) { if m.FieldCleared(user.FieldTotpEnabledAt) {
fields = append(fields, user.FieldTotpEnabledAt) fields = append(fields, user.FieldTotpEnabledAt)
} }
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
fields = append(fields, user.FieldBalanceNotifyThreshold)
}
return fields return fields
} }
@ -29813,6 +30205,9 @@ func (m *UserMutation) ClearField(name string) error {
case user.FieldTotpEnabledAt: case user.FieldTotpEnabledAt:
m.ClearTotpEnabledAt() m.ClearTotpEnabledAt()
return nil return nil
case user.FieldBalanceNotifyThreshold:
m.ClearBalanceNotifyThreshold()
return nil
} }
return fmt.Errorf("unknown User nullable field %s", name) return fmt.Errorf("unknown User nullable field %s", name)
} }
@ -29863,6 +30258,21 @@ func (m *UserMutation) ResetField(name string) error {
case user.FieldTotpEnabledAt: case user.FieldTotpEnabledAt:
m.ResetTotpEnabledAt() m.ResetTotpEnabledAt()
return nil return nil
case user.FieldBalanceNotifyEnabled:
m.ResetBalanceNotifyEnabled()
return nil
case user.FieldBalanceNotifyThresholdType:
m.ResetBalanceNotifyThresholdType()
return nil
case user.FieldBalanceNotifyThreshold:
m.ResetBalanceNotifyThreshold()
return nil
case user.FieldBalanceNotifyExtraEmails:
m.ResetBalanceNotifyExtraEmails()
return nil
case user.FieldTotalRecharged:
m.ResetTotalRecharged()
return nil
} }
return fmt.Errorf("unknown User field %s", name) return fmt.Errorf("unknown User field %s", name)
} }

View File

@ -35,6 +35,8 @@ type PaymentProviderInstance struct {
Limits string `json:"limits,omitempty"` Limits string `json:"limits,omitempty"`
// RefundEnabled holds the value of the "refund_enabled" field. // RefundEnabled holds the value of the "refund_enabled" field.
RefundEnabled bool `json:"refund_enabled,omitempty"` RefundEnabled bool `json:"refund_enabled,omitempty"`
// AllowUserRefund holds the value of the "allow_user_refund" field.
AllowUserRefund bool `json:"allow_user_refund,omitempty"`
// CreatedAt holds the value of the "created_at" field. // CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"`
// UpdatedAt holds the value of the "updated_at" field. // UpdatedAt holds the value of the "updated_at" field.
@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any)
} else if value.Valid { } else if value.Valid {
_m.RefundEnabled = value.Bool _m.RefundEnabled = value.Bool
} }
case paymentproviderinstance.FieldAllowUserRefund:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
} else if value.Valid {
_m.AllowUserRefund = value.Bool
}
case paymentproviderinstance.FieldCreatedAt: case paymentproviderinstance.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok { if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i]) return fmt.Errorf("unexpected type %T for field created_at", values[i])
@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string {
builder.WriteString("refund_enabled=") builder.WriteString("refund_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
builder.WriteString(", ") builder.WriteString(", ")
builder.WriteString("allow_user_refund=")
builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
builder.WriteString(", ")
builder.WriteString("created_at=") builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ") builder.WriteString(", ")

View File

@ -31,6 +31,8 @@ const (
FieldLimits = "limits" FieldLimits = "limits"
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database. // FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
FieldRefundEnabled = "refund_enabled" FieldRefundEnabled = "refund_enabled"
// FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
FieldAllowUserRefund = "allow_user_refund"
// FieldCreatedAt holds the string denoting the created_at field in the database. // FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at" FieldCreatedAt = "created_at"
// FieldUpdatedAt holds the string denoting the updated_at field in the database. // FieldUpdatedAt holds the string denoting the updated_at field in the database.
@ -51,6 +53,7 @@ var Columns = []string{
FieldSortOrder, FieldSortOrder,
FieldLimits, FieldLimits,
FieldRefundEnabled, FieldRefundEnabled,
FieldAllowUserRefund,
FieldCreatedAt, FieldCreatedAt,
FieldUpdatedAt, FieldUpdatedAt,
} }
@ -88,6 +91,8 @@ var (
DefaultLimits string DefaultLimits string
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
DefaultRefundEnabled bool DefaultRefundEnabled bool
// DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
DefaultAllowUserRefund bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field. // DefaultCreatedAt holds the default value on creation for the "created_at" field.
DefaultCreatedAt func() time.Time DefaultCreatedAt func() time.Time
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field. // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
} }
// ByAllowUserRefund orders the results by the allow_user_refund field.
func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field. // ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()

View File

@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
} }
// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.PaymentProviderInstance { func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
} }
// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
}
// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))

View File

@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym
return _c return _c
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
_c.mutation.SetAllowUserRefund(v)
return _c
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
if v != nil {
_c.SetAllowUserRefund(*v)
}
return _c
}
// SetCreatedAt sets the "created_at" field. // SetCreatedAt sets the "created_at" field.
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() {
v := paymentproviderinstance.DefaultRefundEnabled v := paymentproviderinstance.DefaultRefundEnabled
_c.mutation.SetRefundEnabled(v) _c.mutation.SetRefundEnabled(v)
} }
if _, ok := _c.mutation.AllowUserRefund(); !ok {
v := paymentproviderinstance.DefaultAllowUserRefund
_c.mutation.SetAllowUserRefund(v)
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
v := paymentproviderinstance.DefaultCreatedAt() v := paymentproviderinstance.DefaultCreatedAt()
_c.mutation.SetCreatedAt(v) _c.mutation.SetCreatedAt(v)
@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error {
if _, ok := _c.mutation.RefundEnabled(); !ok { if _, ok := _c.mutation.RefundEnabled(); !ok {
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
} }
if _, ok := _c.mutation.AllowUserRefund(); !ok {
return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok { if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
} }
@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance,
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
_node.RefundEnabled = value _node.RefundEnabled = value
} }
if value, ok := _c.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
_node.AllowUserRefund = value
}
if value, ok := _c.mutation.CreatedAt(); ok { if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value _node.CreatedAt = value
@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn
return u return u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
return u
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
return u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
u.Set(paymentproviderinstance.FieldUpdatedAt, v) u.Set(paymentproviderinstance.FieldUpdatedAt, v)
@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide
}) })
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
return u.Update(func(s *PaymentProviderInstanceUpsert) { return u.Update(func(s *PaymentProviderInstanceUpsert) {
@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid
}) })
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.SetAllowUserRefund(v)
})
}
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) {
s.UpdateAllowUserRefund()
})
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
return u.Update(func(s *PaymentProviderInstanceUpsert) { return u.Update(func(s *PaymentProviderInstanceUpsert) {

View File

@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym
return _u return _u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int
if value, ok := _u.mutation.RefundEnabled(); ok { if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
} }
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
} }
@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P
return _u return _u
} }
// SetAllowUserRefund sets the "allow_user_refund" field.
func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetAllowUserRefund(v)
return _u
}
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
if v != nil {
_u.SetAllowUserRefund(*v)
}
return _u
}
// SetUpdatedAt sets the "updated_at" field. // SetUpdatedAt sets the "updated_at" field.
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
_u.mutation.SetUpdatedAt(v) _u.mutation.SetUpdatedAt(v)
@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node
if value, ok := _u.mutation.RefundEnabled(); ok { if value, ok := _u.mutation.RefundEnabled(); ok {
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
} }
if value, ok := _u.mutation.AllowUserRefund(); ok {
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
}
if value, ok := _u.mutation.UpdatedAt(); ok { if value, ok := _u.mutation.UpdatedAt(); ok {
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
} }

View File

@ -668,12 +668,16 @@ func init() {
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
// paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
// paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
@ -1293,6 +1297,22 @@ func init() {
userDescTotpEnabled := userFields[9].Descriptor() userDescTotpEnabled := userFields[9].Descriptor()
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
userDescBalanceNotifyEnabled := userFields[11].Descriptor()
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
// userDescTotalRecharged is the schema descriptor for total_recharged field.
userDescTotalRecharged := userFields[15].Descriptor()
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
userallowedgroupFields := schema.UserAllowedGroup{}.Fields() userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
_ = userallowedgroupFields _ = userallowedgroupFields
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field. // userallowedgroupDescCreatedAt is the schema descriptor for created_at field.

View File

@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field {
Default(""), Default(""),
field.Bool("refund_enabled"). field.Bool("refund_enabled").
Default(false), Default(false),
field.Bool("allow_user_refund").
Default(false),
field.Time("created_at"). field.Time("created_at").
Immutable(). Immutable().
Default(time.Now). Default(time.Now).

View File

@ -72,6 +72,22 @@ func (User) Fields() []ent.Field {
field.Time("totp_enabled_at"). field.Time("totp_enabled_at").
Optional(). Optional().
Nillable(), Nillable(),
// 余额不足通知
field.Bool("balance_notify_enabled").
Default(true),
field.String("balance_notify_threshold_type").
Default("fixed"), // "fixed" | "percentage"
field.Float("balance_notify_threshold").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Optional().
Nillable(),
field.String("balance_notify_extra_emails").
SchemaType(map[string]string{dialect.Postgres: "text"}).
Default("[]"),
field.Float("total_recharged").
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
Default(0),
} }
} }

View File

@ -45,6 +45,16 @@ type User struct {
TotpEnabled bool `json:"totp_enabled,omitempty"` TotpEnabled bool `json:"totp_enabled,omitempty"`
// TotpEnabledAt holds the value of the "totp_enabled_at" field. // TotpEnabledAt holds the value of the "totp_enabled_at" field.
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
// BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
// BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
// TotalRecharged holds the value of the "total_recharged" field.
TotalRecharged float64 `json:"total_recharged,omitempty"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the UserQuery when eager-loading is set. // The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"` Edges UserEdges `json:"edges"`
@ -184,13 +194,13 @@ func (*User) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns)) values := make([]any, len(columns))
for i := range columns { for i := range columns {
switch columns[i] { switch columns[i] {
case user.FieldTotpEnabled: case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case user.FieldBalance: case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
values[i] = new(sql.NullFloat64) values[i] = new(sql.NullFloat64)
case user.FieldID, user.FieldConcurrency: case user.FieldID, user.FieldConcurrency:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -302,6 +312,37 @@ func (_m *User) assignValues(columns []string, values []any) error {
_m.TotpEnabledAt = new(time.Time) _m.TotpEnabledAt = new(time.Time)
*_m.TotpEnabledAt = value.Time *_m.TotpEnabledAt = value.Time
} }
case user.FieldBalanceNotifyEnabled:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
} else if value.Valid {
_m.BalanceNotifyEnabled = value.Bool
}
case user.FieldBalanceNotifyThresholdType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
} else if value.Valid {
_m.BalanceNotifyThresholdType = value.String
}
case user.FieldBalanceNotifyThreshold:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
} else if value.Valid {
_m.BalanceNotifyThreshold = new(float64)
*_m.BalanceNotifyThreshold = value.Float64
}
case user.FieldBalanceNotifyExtraEmails:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
} else if value.Valid {
_m.BalanceNotifyExtraEmails = value.String
}
case user.FieldTotalRecharged:
if value, ok := values[i].(*sql.NullFloat64); !ok {
return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
} else if value.Valid {
_m.TotalRecharged = value.Float64
}
default: default:
_m.selectValues.Set(columns[i], values[i]) _m.selectValues.Set(columns[i], values[i])
} }
@ -440,6 +481,23 @@ func (_m *User) String() string {
builder.WriteString("totp_enabled_at=") builder.WriteString("totp_enabled_at=")
builder.WriteString(v.Format(time.ANSIC)) builder.WriteString(v.Format(time.ANSIC))
} }
builder.WriteString(", ")
builder.WriteString("balance_notify_enabled=")
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
builder.WriteString(", ")
builder.WriteString("balance_notify_threshold_type=")
builder.WriteString(_m.BalanceNotifyThresholdType)
builder.WriteString(", ")
if v := _m.BalanceNotifyThreshold; v != nil {
builder.WriteString("balance_notify_threshold=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
builder.WriteString(", ")
builder.WriteString("balance_notify_extra_emails=")
builder.WriteString(_m.BalanceNotifyExtraEmails)
builder.WriteString(", ")
builder.WriteString("total_recharged=")
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View File

@ -43,6 +43,16 @@ const (
FieldTotpEnabled = "totp_enabled" FieldTotpEnabled = "totp_enabled"
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
FieldTotpEnabledAt = "totp_enabled_at" FieldTotpEnabledAt = "totp_enabled_at"
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
FieldBalanceNotifyEnabled = "balance_notify_enabled"
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
// FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
FieldBalanceNotifyThreshold = "balance_notify_threshold"
// FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
FieldTotalRecharged = "total_recharged"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys" EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@ -161,6 +171,11 @@ var Columns = []string{
FieldTotpSecretEncrypted, FieldTotpSecretEncrypted,
FieldTotpEnabled, FieldTotpEnabled,
FieldTotpEnabledAt, FieldTotpEnabledAt,
FieldBalanceNotifyEnabled,
FieldBalanceNotifyThresholdType,
FieldBalanceNotifyThreshold,
FieldBalanceNotifyExtraEmails,
FieldTotalRecharged,
} }
var ( var (
@ -217,6 +232,14 @@ var (
DefaultNotes string DefaultNotes string
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
DefaultTotpEnabled bool DefaultTotpEnabled bool
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
DefaultBalanceNotifyEnabled bool
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
DefaultBalanceNotifyThresholdType string
// DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
DefaultBalanceNotifyExtraEmails string
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
DefaultTotalRecharged float64
) )
// OrderOption defines the ordering options for the User queries. // OrderOption defines the ordering options for the User queries.
@ -297,6 +320,31 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
} }
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
}
// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
}
// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
}
// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
}
// ByTotalRecharged orders the results by the total_recharged field.
func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
}
// ByAPIKeysCount orders the results by api_keys count. // ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) { return func(s *sql.Selector) {

View File

@ -125,6 +125,31 @@ func TotpEnabledAt(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
} }
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
func BalanceNotifyEnabled(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
func BalanceNotifyThresholdType(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
func BalanceNotifyThreshold(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
func BalanceNotifyExtraEmails(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
}
// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
func TotalRecharged(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.User { func CreatedAtEQ(v time.Time) predicate.User {
return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
@ -860,6 +885,236 @@ func TotpEnabledAtNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
} }
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledEQ(v bool) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
func BalanceNotifyEnabledNEQ(v bool) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
}
// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
}
// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
}
// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
}
// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdEQ(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNEQ(v float64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
}
// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
}
// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdGT(v float64) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdGTE(v float64) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdLT(v float64) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
}
// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdIsNil() predicate.User {
return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
}
// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
func BalanceNotifyThresholdNotNil() predicate.User {
return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
}
// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
}
// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
}
// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsGT(v string) predicate.User {
return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsLT(v string) predicate.User {
return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsContains(v string) predicate.User {
return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
}
// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
}
// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
func TotalRechargedEQ(v float64) predicate.User {
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
}
// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
func TotalRechargedNEQ(v float64) predicate.User {
return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
}
// TotalRechargedIn applies the In predicate on the "total_recharged" field.
func TotalRechargedIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
}
// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
func TotalRechargedNotIn(vs ...float64) predicate.User {
return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
}
// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
func TotalRechargedGT(v float64) predicate.User {
return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
}
// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
func TotalRechargedGTE(v float64) predicate.User {
return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
}
// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
func TotalRechargedLT(v float64) predicate.User {
return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
}
// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
func TotalRechargedLTE(v float64) predicate.User {
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
}
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.User { func HasAPIKeys() predicate.User {
return predicate.User(func(s *sql.Selector) { return predicate.User(func(s *sql.Selector) {

View File

@ -211,6 +211,76 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
return _c return _c
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
_c.mutation.SetBalanceNotifyEnabled(v)
return _c
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
if v != nil {
_c.SetBalanceNotifyEnabled(*v)
}
return _c
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
_c.mutation.SetBalanceNotifyThresholdType(v)
return _c
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
if v != nil {
_c.SetBalanceNotifyThresholdType(*v)
}
return _c
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
_c.mutation.SetBalanceNotifyThreshold(v)
return _c
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
if v != nil {
_c.SetBalanceNotifyThreshold(*v)
}
return _c
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
_c.mutation.SetBalanceNotifyExtraEmails(v)
return _c
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
if v != nil {
_c.SetBalanceNotifyExtraEmails(*v)
}
return _c
}
// SetTotalRecharged sets the "total_recharged" field.
func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
_c.mutation.SetTotalRecharged(v)
return _c
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
if v != nil {
_c.SetTotalRecharged(*v)
}
return _c
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
_c.mutation.AddAPIKeyIDs(ids...) _c.mutation.AddAPIKeyIDs(ids...)
@ -440,6 +510,22 @@ func (_c *UserCreate) defaults() error {
v := user.DefaultTotpEnabled v := user.DefaultTotpEnabled
_c.mutation.SetTotpEnabled(v) _c.mutation.SetTotpEnabled(v)
} }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
v := user.DefaultBalanceNotifyEnabled
_c.mutation.SetBalanceNotifyEnabled(v)
}
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
v := user.DefaultBalanceNotifyThresholdType
_c.mutation.SetBalanceNotifyThresholdType(v)
}
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
v := user.DefaultBalanceNotifyExtraEmails
_c.mutation.SetBalanceNotifyExtraEmails(v)
}
if _, ok := _c.mutation.TotalRecharged(); !ok {
v := user.DefaultTotalRecharged
_c.mutation.SetTotalRecharged(v)
}
return nil return nil
} }
@ -503,6 +589,18 @@ func (_c *UserCreate) check() error {
if _, ok := _c.mutation.TotpEnabled(); !ok { if _, ok := _c.mutation.TotpEnabled(); !ok {
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
} }
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
}
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
}
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
}
if _, ok := _c.mutation.TotalRecharged(); !ok {
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
}
return nil return nil
} }
@ -586,6 +684,26 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
_node.TotpEnabledAt = &value _node.TotpEnabledAt = &value
} }
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
_node.BalanceNotifyEnabled = value
}
if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
_node.BalanceNotifyThresholdType = value
}
if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
_node.BalanceNotifyThreshold = &value
}
if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
_node.BalanceNotifyExtraEmails = value
}
if value, ok := _c.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
_node.TotalRecharged = value
}
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
@ -988,6 +1106,84 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
return u return u
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
u.Set(user.FieldBalanceNotifyEnabled, v)
return u
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyEnabled)
return u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
u.Set(user.FieldBalanceNotifyThresholdType, v)
return u
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyThresholdType)
return u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
u.Set(user.FieldBalanceNotifyThreshold, v)
return u
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyThreshold)
return u
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
u.Add(user.FieldBalanceNotifyThreshold, v)
return u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
u.SetNull(user.FieldBalanceNotifyThreshold)
return u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
u.Set(user.FieldBalanceNotifyExtraEmails, v)
return u
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
return u
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
u.Set(user.FieldTotalRecharged, v)
return u
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
u.SetExcluded(user.FieldTotalRecharged)
return u
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
u.Add(user.FieldTotalRecharged, v)
return u
}
// UpdateNewValues updates the mutable fields using the new values that were set on create. // UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using: // Using this option is equivalent to using:
// //
@ -1250,6 +1446,97 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
}) })
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyEnabled(v)
})
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyEnabled()
})
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThresholdType(v)
})
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThresholdType()
})
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThreshold(v)
})
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddBalanceNotifyThreshold(v)
})
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThreshold()
})
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.ClearBalanceNotifyThreshold()
})
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyExtraEmails(v)
})
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyExtraEmails()
})
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.SetTotalRecharged(v)
})
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.AddTotalRecharged(v)
})
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
return u.Update(func(s *UserUpsert) {
s.UpdateTotalRecharged()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UserUpsertOne) Exec(ctx context.Context) error { func (u *UserUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 { if len(u.create.conflict) == 0 {
@ -1678,6 +1965,97 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
}) })
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyEnabled(v)
})
}
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyEnabled()
})
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThresholdType(v)
})
}
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThresholdType()
})
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyThreshold(v)
})
}
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddBalanceNotifyThreshold(v)
})
}
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyThreshold()
})
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.ClearBalanceNotifyThreshold()
})
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetBalanceNotifyExtraEmails(v)
})
}
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateBalanceNotifyExtraEmails()
})
}
// SetTotalRecharged sets the "total_recharged" field.
func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.SetTotalRecharged(v)
})
}
// AddTotalRecharged adds v to the "total_recharged" field.
func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.AddTotalRecharged(v)
})
}
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
return u.Update(func(s *UserUpsert) {
s.UpdateTotalRecharged()
})
}
// Exec executes the query. // Exec executes the query.
func (u *UserUpsertBulk) Exec(ctx context.Context) error { func (u *UserUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil { if u.create.err != nil {

View File

@ -243,6 +243,96 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
return _u return _u
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
_u.mutation.SetBalanceNotifyEnabled(v)
return _u
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyEnabled(*v)
}
return _u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
_u.mutation.SetBalanceNotifyThresholdType(v)
return _u
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyThresholdType(*v)
}
return _u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
_u.mutation.ResetBalanceNotifyThreshold()
_u.mutation.SetBalanceNotifyThreshold(v)
return _u
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyThreshold(*v)
}
return _u
}
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
_u.mutation.AddBalanceNotifyThreshold(v)
return _u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
_u.mutation.ClearBalanceNotifyThreshold()
return _u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
_u.mutation.SetBalanceNotifyExtraEmails(v)
return _u
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
if v != nil {
_u.SetBalanceNotifyExtraEmails(*v)
}
return _u
}
// SetTotalRecharged sets the "total_recharged" field.
func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
_u.mutation.ResetTotalRecharged()
_u.mutation.SetTotalRecharged(v)
return _u
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
if v != nil {
_u.SetTotalRecharged(*v)
}
return _u
}
// AddTotalRecharged adds value to the "total_recharged" field.
func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
_u.mutation.AddTotalRecharged(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
@ -746,6 +836,30 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.TotpEnabledAtCleared() { if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
} }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
}
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if _u.mutation.BalanceNotifyThresholdCleared() {
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
}
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
if value, ok := _u.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
@ -1434,6 +1548,96 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
return _u return _u
} }
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
_u.mutation.SetBalanceNotifyEnabled(v)
return _u
}
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyEnabled(*v)
}
return _u
}
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
_u.mutation.SetBalanceNotifyThresholdType(v)
return _u
}
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyThresholdType(*v)
}
return _u
}
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
_u.mutation.ResetBalanceNotifyThreshold()
_u.mutation.SetBalanceNotifyThreshold(v)
return _u
}
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyThreshold(*v)
}
return _u
}
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
_u.mutation.AddBalanceNotifyThreshold(v)
return _u
}
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
_u.mutation.ClearBalanceNotifyThreshold()
return _u
}
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
_u.mutation.SetBalanceNotifyExtraEmails(v)
return _u
}
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
if v != nil {
_u.SetBalanceNotifyExtraEmails(*v)
}
return _u
}
// SetTotalRecharged sets the "total_recharged" field.
func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
_u.mutation.ResetTotalRecharged()
_u.mutation.SetTotalRecharged(v)
return _u
}
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
if v != nil {
_u.SetTotalRecharged(*v)
}
return _u
}
// AddTotalRecharged adds value to the "total_recharged" field.
func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
_u.mutation.AddTotalRecharged(v)
return _u
}
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...) _u.mutation.AddAPIKeyIDs(ids...)
@ -1967,6 +2171,30 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
if _u.mutation.TotpEnabledAtCleared() { if _u.mutation.TotpEnabledAtCleared() {
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
} }
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
}
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
}
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
}
if _u.mutation.BalanceNotifyThresholdCleared() {
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
}
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
}
if value, ok := _u.mutation.TotalRecharged(); ok {
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
}
if _u.mutation.APIKeysCleared() { if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,

View File

@ -183,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@ -218,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -251,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -280,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@ -312,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@ -28,7 +28,7 @@ const (
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support // DefaultCSPPolicy is the default Content-Security-Policy with nonce support
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// UMQ用户消息队列模式常量 // UMQ用户消息队列模式常量
const ( const (

View File

@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
configPath := filepath.Join(tempDir, "config.yaml") configPath := filepath.Join(tempDir, "config.yaml")
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644)) yamlSafePath := filepath.ToSlash(templatePath)
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
t.Setenv("DATA_DIR", tempDir) t.Setenv("DATA_DIR", tempDir)
cfg, err := Load() cfg, err := Load()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate) require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
} }

View File

@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
c.JSON(409, gin.H{ c.JSON(409, gin.H{
"error": "mixed_channel_warning", "error": "mixed_channel_warning",
"message": mixedErr.Error(), "message": mixedErr.Error(),
"details": gin.H{
"group_id": mixedErr.GroupID,
"group_name": mixedErr.GroupName,
"current_platform": mixedErr.CurrentPlatform,
"other_platform": mixedErr.OtherPlatform,
},
}) })
return return
} }

View File

@ -1,6 +1,7 @@
package admin package admin
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
// --- Request / Response types --- // --- Request / Response types ---
type createChannelRequest struct { type createChannelRequest struct {
Name string `json:"name" binding:"required,max=100"` Name string `json:"name" binding:"required,max=100"`
Description string `json:"description"` Description string `json:"description"`
GroupIDs []int64 `json:"group_ids"` GroupIDs []int64 `json:"group_ids"`
ModelPricing []channelModelPricingRequest `json:"model_pricing"` ModelPricing []channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
Features string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type updateChannelRequest struct { type updateChannelRequest struct {
Name string `json:"name" binding:"omitempty,max=100"` Name string `json:"name" binding:"omitempty,max=100"`
Description *string `json:"description"` Description *string `json:"description"`
Status string `json:"status" binding:"omitempty,oneof=active disabled"` Status string `json:"status" binding:"omitempty,oneof=active disabled"`
GroupIDs *[]int64 `json:"group_ids"` GroupIDs *[]int64 `json:"group_ids"`
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
ModelMapping map[string]map[string]string `json:"model_mapping"` ModelMapping map[string]map[string]string `json:"model_mapping"`
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
RestrictModels *bool `json:"restrict_models"` RestrictModels *bool `json:"restrict_models"`
Features *string `json:"features"`
FeaturesConfig map[string]any `json:"features_config"`
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
} }
type channelModelPricingRequest struct { type channelModelPricingRequest struct {
@ -71,18 +80,29 @@ type pricingIntervalRequest struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleRequest struct {
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingRequest `json:"pricing"`
}
type channelResponse struct { type channelResponse struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Status string `json:"status"` Status string `json:"status"`
BillingModelSource string `json:"billing_model_source"` BillingModelSource string `json:"billing_model_source"`
RestrictModels bool `json:"restrict_models"` RestrictModels bool `json:"restrict_models"`
GroupIDs []int64 `json:"group_ids"` Features string `json:"features"`
ModelPricing []channelModelPricingResponse `json:"model_pricing"` FeaturesConfig map[string]any `json:"features_config"`
ModelMapping map[string]map[string]string `json:"model_mapping"` GroupIDs []int64 `json:"group_ids"`
CreatedAt string `json:"created_at"` ModelPricing []channelModelPricingResponse `json:"model_pricing"`
UpdatedAt string `json:"updated_at"` ModelMapping map[string]map[string]string `json:"model_mapping"`
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
} }
type channelModelPricingResponse struct { type channelModelPricingResponse struct {
@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
} }
type accountStatsPricingRuleResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
GroupIDs []int64 `json:"group_ids"`
AccountIDs []int64 `json:"account_ids"`
Pricing []channelModelPricingResponse `json:"pricing"`
}
func channelToResponse(ch *service.Channel) *channelResponse { func channelToResponse(ch *service.Channel) *channelResponse {
if ch == nil { if ch == nil {
return nil return nil
@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
Description: ch.Description, Description: ch.Description,
Status: ch.Status, Status: ch.Status,
RestrictModels: ch.RestrictModels, RestrictModels: ch.RestrictModels,
Features: ch.Features,
FeaturesConfig: ch.FeaturesConfig,
GroupIDs: ch.GroupIDs, GroupIDs: ch.GroupIDs,
ModelMapping: ch.ModelMapping, ModelMapping: ch.ModelMapping,
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
for _, p := range ch.ModelPricing { for _, p := range ch.ModelPricing {
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
} }
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
for _, rule := range ch.AccountStatsPricingRules {
ruleResp := accountStatsPricingRuleResponse{
ID: rule.ID,
Name: rule.Name,
GroupIDs: rule.GroupIDs,
AccountIDs: rule.AccountIDs,
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
}
if ruleResp.GroupIDs == nil {
ruleResp.GroupIDs = []int64{}
}
if ruleResp.AccountIDs == nil {
ruleResp.AccountIDs = []int64{}
}
for i := range rule.Pricing {
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
}
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
}
return resp return resp
} }
@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
billingMode = service.BillingModeToken billingMode = service.BillingModeToken
} }
platform := r.Platform platform := r.Platform
if platform == "" {
platform = service.PlatformAnthropic
}
intervals := make([]service.PricingInterval, 0, len(r.Intervals)) intervals := make([]service.PricingInterval, 0, len(r.Intervals))
for _, iv := range r.Intervals { for _, iv := range r.Intervals {
intervals = append(intervals, service.PricingInterval{ intervals = append(intervals, service.PricingInterval{
@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result return result
} }
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
return service.AccountStatsPricingRule{
Name: r.Name,
GroupIDs: r.GroupIDs,
AccountIDs: r.AccountIDs,
Pricing: pricingRequestToService(r.Pricing),
}
}
// --- Handlers --- // --- Handlers ---
// List handles listing channels with pagination // List handles listing channels with pagination
@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) {
} }
pricing := pricingRequestToService(req.ModelPricing) pricing := pricingRequestToService(req.ModelPricing)
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
var statsRules []service.AccountStatsPricingRule
for i, r := range req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelPricing: pricing, ModelPricing: pricing,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
AccountStatsPricingRules: statsRules,
}) })
if err != nil { if err != nil {
response.ErrorFrom(c, err) response.ErrorFrom(c, err)
@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) {
} }
input := &service.UpdateChannelInput{ input := &service.UpdateChannelInput{
Name: req.Name, Name: req.Name,
Description: req.Description, Description: req.Description,
Status: req.Status, Status: req.Status,
GroupIDs: req.GroupIDs, GroupIDs: req.GroupIDs,
ModelMapping: req.ModelMapping, ModelMapping: req.ModelMapping,
BillingModelSource: req.BillingModelSource, BillingModelSource: req.BillingModelSource,
RestrictModels: req.RestrictModels, RestrictModels: req.RestrictModels,
Features: req.Features,
FeaturesConfig: req.FeaturesConfig,
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
} }
if req.ModelPricing != nil { if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing) pricing := pricingRequestToService(*req.ModelPricing)
for i := range pricing {
if pricing[i].Platform == "" {
pricing[i].Platform = service.PlatformAnthropic
}
}
input.ModelPricing = &pricing input.ModelPricing = &pricing
} }
if req.AccountStatsPricingRules != nil {
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
for i, r := range *req.AccountStatsPricingRules {
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
return
}
if len(r.Pricing) == 0 {
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
return
}
rule := accountStatsPricingRuleRequestToService(r)
rule.SortOrder = i
statsRules = append(statsRules, rule)
}
input.AccountStatsPricingRules = &statsRules
}
channel, err := h.channelService.Update(c.Request.Context(), id, input) channel, err := h.channelService.Update(c.Request.Context(), id, input)
if err != nil { if err != nil {

View File

@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
wantValue: string(service.BillingModeToken), wantValue: string(service.BillingModeToken),
}, },
{ {
name: "empty platform defaults to anthropic", name: "empty platform stays empty",
req: channelModelPricingRequest{ req: channelModelPricingRequest{
Models: []string{"m1"}, Models: []string{"m1"},
Platform: "", Platform: "",
}, },
wantField: "Platform", wantField: "Platform",
wantValue: "anthropic", wantValue: "",
}, },
} }

View File

@ -5,11 +5,10 @@ import (
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/handler/dto"
@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning, EnableCCHSigning: settings.EnableCCHSigning,
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled, PaymentEnabled: paymentCfg.Enabled,
PaymentMinAmount: paymentCfg.MinAmount, PaymentMinAmount: paymentCfg.MinAmount,
PaymentMaxAmount: paymentCfg.MaxAmount, PaymentMaxAmount: paymentCfg.MaxAmount,
@ -304,6 +309,13 @@ type UpdateSettingsRequest struct {
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"` EnableCCHSigning *bool `json:"enable_cch_signing"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace) // Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"` PaymentEnabled *bool `json:"payment_enabled"`
PaymentMinAmount *float64 `json:"payment_min_amount"` PaymentMinAmount *float64 `json:"payment_min_amount"`
@ -881,6 +893,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
} }
return previousSettings.EnableCCHSigning return previousSettings.EnableCCHSigning
}(), }(),
BalanceLowNotifyEnabled: func() bool {
if req.BalanceLowNotifyEnabled != nil {
return *req.BalanceLowNotifyEnabled
}
return previousSettings.BalanceLowNotifyEnabled
}(),
BalanceLowNotifyThreshold: func() float64 {
if req.BalanceLowNotifyThreshold != nil {
return *req.BalanceLowNotifyThreshold
}
return previousSettings.BalanceLowNotifyThreshold
}(),
BalanceLowNotifyRechargeURL: func() string {
if req.BalanceLowNotifyRechargeURL != nil {
return *req.BalanceLowNotifyRechargeURL
}
return previousSettings.BalanceLowNotifyRechargeURL
}(),
AccountQuotaNotifyEnabled: func() bool {
if req.AccountQuotaNotifyEnabled != nil {
return *req.AccountQuotaNotifyEnabled
}
return previousSettings.AccountQuotaNotifyEnabled
}(),
AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
if req.AccountQuotaNotifyEmails != nil {
return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
}
return previousSettings.AccountQuotaNotifyEmails
}(),
} }
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@ -1027,6 +1069,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning, EnableCCHSigning: updatedSettings.EnableCCHSigning,
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled, PaymentEnabled: updatedPaymentCfg.Enabled,
PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMinAmount: updatedPaymentCfg.MinAmount,
PaymentMaxAmount: updatedPaymentCfg.MaxAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
@ -1073,11 +1120,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
subject, _ := middleware.GetAuthSubjectFromContext(c) subject, _ := middleware.GetAuthSubjectFromContext(c)
role, _ := middleware.GetUserRoleFromContext(c) role, _ := middleware.GetUserRoleFromContext(c)
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v", slog.Info("settings updated",
time.Now().UTC().Format(time.RFC3339), "audit", true,
subject.UserID, "user_id", subject.UserID,
role, "role", role,
changed, "changed", changed,
) )
} }
@ -1092,6 +1139,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
changed = append(changed, "registration_email_suffix_whitelist") changed = append(changed, "registration_email_suffix_whitelist")
} }
if before.PromoCodeEnabled != after.PromoCodeEnabled {
changed = append(changed, "promo_code_enabled")
}
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
changed = append(changed, "invitation_code_enabled")
}
if before.PasswordResetEnabled != after.PasswordResetEnabled { if before.PasswordResetEnabled != after.PasswordResetEnabled {
changed = append(changed, "password_reset_enabled") changed = append(changed, "password_reset_enabled")
} }
@ -1302,6 +1355,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.CustomMenuItems != after.CustomMenuItems { if before.CustomMenuItems != after.CustomMenuItems {
changed = append(changed, "custom_menu_items") changed = append(changed, "custom_menu_items")
} }
if before.CustomEndpoints != after.CustomEndpoints {
changed = append(changed, "custom_endpoints")
}
if before.EnableFingerprintUnification != after.EnableFingerprintUnification { if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
changed = append(changed, "enable_fingerprint_unification") changed = append(changed, "enable_fingerprint_unification")
} }
@ -1311,6 +1367,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableCCHSigning != after.EnableCCHSigning { if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing") changed = append(changed, "enable_cch_signing")
} }
// Balance & quota notification
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
}
if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
changed = append(changed, "balance_low_notify_threshold")
}
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
changed = append(changed, "balance_low_notify_recharge_url")
}
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
changed = append(changed, "account_quota_notify_enabled")
}
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
changed = append(changed, "account_quota_notify_emails")
}
return changed return changed
} }
@ -1367,6 +1439,18 @@ func equalIntSlice(a, b []int) bool {
return true return true
} }
func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
return false
}
}
return true
}
// TestSMTPRequest 测试SMTP连接请求 // TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct { type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host"` SMTPHost string `json:"smtp_host"`
@ -1847,3 +1931,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes, ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
}) })
} }
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
// GET /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
}
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
// PUT /api/v1/admin/settings/web-search-emulation
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
var cfg service.WebSearchEmulationConfig
if err := c.ShouldBindJSON(&cfg); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
response.ErrorFrom(c, err)
return
}
// Re-read (with sanitized api keys) to return current state
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
}
// ResetWebSearchUsage 重置指定 provider 的配额用量
// POST /api/v1/admin/settings/web-search-emulation/reset-usage
func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
var req struct {
ProviderType string `json:"provider_type"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if req.ProviderType == "" {
response.BadRequest(c, "provider_type is required")
return
}
if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, nil)
}
// TestWebSearchEmulation 测试 Web Search 搜索
// POST /api/v1/admin/settings/web-search-emulation/test
func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
var req struct {
Query string `json:"query"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Query) == "" {
req.Query = "搜索今年世界大事件"
}
result, err := service.TestWebSearch(c.Request.Context(), req.Query)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}

View File

@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
return nil return nil
} }
return &User{ return &User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
AllowedGroups: u.AllowedGroups, AllowedGroups: u.AllowedGroups,
CreatedAt: u.CreatedAt, CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt, UpdatedAt: u.UpdatedAt,
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
TotalRecharged: u.TotalRecharged,
} }
} }
@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
out.QuotaWeeklyResetAt = &v out.QuotaWeeklyResetAt = &v
} }
} }
// 配额通知配置
if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
out.QuotaNotifyDailyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
out.QuotaNotifyDailyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
out.QuotaNotifyWeeklyEnabled = &enabled
}
if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
out.QuotaNotifyWeeklyThreshold = &threshold
}
if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
out.QuotaNotifyTotalEnabled = &enabled
}
if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
out.QuotaNotifyTotalThreshold = &threshold
}
} }
return out return out
@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
ModelMappingChain: l.ModelMappingChain, ModelMappingChain: l.ModelMappingChain,
BillingTier: l.BillingTier, BillingTier: l.BillingTier,
AccountRateMultiplier: l.AccountRateMultiplier, AccountRateMultiplier: l.AccountRateMultiplier,
AccountStatsCost: l.AccountStatsCost,
IPAddress: l.IPAddress, IPAddress: l.IPAddress,
Account: AccountSummaryFromService(l.Account), Account: AccountSummaryFromService(l.Account),
} }

View File

@ -0,0 +1,43 @@
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// NotifyEmailEntriesFromService converts service entries to DTO entries.
func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}
// NotifyEmailEntriesToService converts DTO entries to service entries.
func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
if entries == nil {
return nil
}
result := make([]service.NotifyEmailEntry, len(entries))
for i, e := range entries {
result[i] = service.NotifyEmailEntry{
Email: e.Email,
Disabled: e.Disabled,
Verified: e.Verified,
}
}
return result
}

View File

@ -124,6 +124,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"` EnableCCHSigning bool `json:"enable_cch_signing"`
// Web Search Emulation
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
// Payment configuration // Payment configuration
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
PaymentMinAmount float64 `json:"payment_min_amount"` PaymentMinAmount float64 `json:"payment_min_amount"`
@ -145,6 +148,13 @@ type SystemSettings struct {
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"` PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"` PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
} }
type DefaultSubscriptionSetting struct { type DefaultSubscriptionSetting struct {
@ -183,6 +193,10 @@ type PublicSettings struct {
BackendModeEnabled bool `json:"backend_mode_enabled"` BackendModeEnabled bool `json:"backend_mode_enabled"`
PaymentEnabled bool `json:"payment_enabled"` PaymentEnabled bool `json:"payment_enabled"`
Version string `json:"version"` Version string `json:"version"`
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
} }
// OverloadCooldownSettings 529过载冷却配置 DTO // OverloadCooldownSettings 529过载冷却配置 DTO

View File

@ -18,6 +18,13 @@ type User struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
// 余额不足通知
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
TotalRecharged float64 `json:"total_recharged"`
APIKeys []APIKey `json:"api_keys,omitempty"` APIKeys []APIKey `json:"api_keys,omitempty"`
Subscriptions []UserSubscription `json:"subscriptions,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
} }
@ -218,6 +225,14 @@ type Account struct {
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
// 配额通知配置
QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
Proxy *Proxy `json:"proxy,omitempty"` Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@ -412,6 +427,8 @@ type AdminUsageLog struct {
// AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理) // AccountRateMultiplier 账号计费倍率快照nil 表示按 1.0 处理)
AccountRateMultiplier *float64 `json:"account_rate_multiplier"` AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
// AccountStatsCost 自定义定价规则计算的账号统计费用nil 表示使用默认公式)
AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
// IPAddress 用户请求 IP仅管理员可见 // IPAddress 用户请求 IP仅管理员可见
IPAddress *string `json:"ip_address,omitempty"` IPAddress *string `json:"ip_address,omitempty"`

View File

@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return return
} }
// 设置请求所属分组 ID用于渠道级功能判断如 WebSearch 模拟)
parsedReq.GroupID = apiKey.GroupID
// 计算粘性会话hash // 计算粘性会话hash
parsedReq.SessionContext = &service.SessionContext{ parsedReq.SessionContext = &service.SessionContext{
ClientIP: ip.GetClientIP(c), ClientIP: ip.GetClientIP(c),
@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: apiKey, APIKey: apiKey,
User: apiKey.User, User: apiKey.User,
Account: account, Account: account,
@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for { for {
// 选择支持该模型的账号 // 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil { if err != nil {
if len(fs.FailedAccountIDs) == 0 { if len(fs.FailedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
} }
// 转发请求 - 根据账号平台分流 // 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)
var result *service.ForwardResult var result *service.ForwardResult
requestCtx := c.Request.Context() requestCtx := c.Request.Context()
if fs.SwitchCount > 0 { if fs.SwitchCount > 0 {
@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.submitUsageRecordTask(func(ctx context.Context) { h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result, Result: result,
ParsedRequest: parsedReq,
APIKey: currentAPIKey, APIKey: currentAPIKey,
User: currentAPIKey.User, User: currentAPIKey.User,
Account: account, Account: account,

View File

@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // tlsFPProfileService nil, // tlsFPProfileService
nil, // channelService nil, // channelService
nil, // resolver nil, // resolver
nil, // balanceNotifyService
) )
// RunModeSimple跳过计费检查避免引入 repo/cache 依赖。 // RunModeSimple跳过计费检查避免引入 repo/cache 依赖。

View File

@ -335,6 +335,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
response.Success(c, gin.H{"message": "refund requested"}) response.Success(c, gin.H{"message": "refund requested"})
} }
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"provider_instance_ids": ids})
}
// VerifyOrderRequest is the request body for verifying a payment order. // VerifyOrderRequest is the request body for verifying a payment order.
type VerifyOrderRequest struct { type VerifyOrderRequest struct {
OutTradeNo string `json:"out_trade_no" binding:"required"` OutTradeNo string `json:"out_trade_no" binding:"required"`

View File

@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled, BackendModeEnabled: settings.BackendModeEnabled,
PaymentEnabled: settings.PaymentEnabled, PaymentEnabled: settings.PaymentEnabled,
Version: h.version, Version: h.version,
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
}) })
} }

View File

@ -11,13 +11,17 @@ import (
// UserHandler handles user-related requests // UserHandler handles user-related requests
type UserHandler struct { type UserHandler struct {
userService *service.UserService userService *service.UserService
emailService *service.EmailService
emailCache service.EmailCache
} }
// NewUserHandler creates a new UserHandler // NewUserHandler creates a new UserHandler
func NewUserHandler(userService *service.UserService) *UserHandler { func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
return &UserHandler{ return &UserHandler{
userService: userService, userService: userService,
emailService: emailService,
emailCache: emailCache,
} }
} }
@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
// UpdateProfileRequest represents the update profile request payload // UpdateProfileRequest represents the update profile request payload
type UpdateProfileRequest struct { type UpdateProfileRequest struct {
Username *string `json:"username"` Username *string `json:"username"`
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
} }
// GetProfile handles getting user profile // GetProfile handles getting user profile
@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
} }
svcReq := service.UpdateProfileRequest{ svcReq := service.UpdateProfileRequest{
Username: req.Username, Username: req.Username,
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
} }
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
if err != nil { if err != nil {
@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
response.Success(c, dto.UserFromService(updatedUser)) response.Success(c, dto.UserFromService(updatedUser))
} }
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
type SendNotifyEmailCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
// SendNotifyEmailCode sends verification code to extra notification email
// POST /api/v1/user/notify-email/send-code
func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req SendNotifyEmailCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{"message": "Verification code sent successfully"})
}
// VerifyNotifyEmailRequest represents the request to verify and add notify email
type VerifyNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Code string `json:"code" binding:"required,len=6"`
}
// VerifyNotifyEmail verifies code and adds email to notification list
// POST /api/v1/user/notify-email/verify
func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req VerifyNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// RemoveNotifyEmailRequest represents the request to remove a notify email
type RemoveNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
}
// RemoveNotifyEmail removes email from notification list
// DELETE /api/v1/user/notify-email
func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req RemoveNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
if err != nil {
response.ErrorFrom(c, err)
return
}
// Return updated user
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
type ToggleNotifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
Disabled bool `json:"disabled"`
}
// ToggleNotifyEmail toggles the disabled state of a notification email
// PUT /api/v1/user/notify-email/toggle
func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
subject, ok := middleware2.GetAuthSubjectFromContext(c)
if !ok {
response.Unauthorized(c, "User not authenticated")
return
}
var req ToggleNotifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
if err != nil {
response.ErrorFrom(c, err)
return
}
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.UserFromService(updatedUser))
}

View File

@ -117,7 +117,13 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
var matched []*dbent.PaymentProviderInstance var matched []*dbent.PaymentProviderInstance
for _, inst := range instances { for _, inst := range instances {
if InstanceSupportsType(inst.SupportedTypes, paymentType) { // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
if paymentType == TypeStripe {
if inst.ProviderKey == TypeStripe {
matched = append(matched, inst)
}
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
matched = append(matched, inst) matched = append(matched, inst)
} }
} }

View File

@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) {
wantIDs: nil, wantIDs: nil,
}, },
{ {
name: "empty candidates returns empty", name: "empty candidates returns empty",
candidates: nil, candidates: nil,
paymentType: "alipay", paymentType: "alipay",
orderAmount: 10, orderAmount: 10,

View File

@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) {
errSubstr: "privateKey", errSubstr: "privateKey",
}, },
{ {
name: "nil config map returns error for appId", name: "nil config map returns error for appId",
config: map[string]string{}, config: map[string]string{},
wantErr: true, wantErr: true,
errSubstr: "appId", errSubstr: "appId",
}, },
} }

View File

@ -18,6 +18,9 @@ const (
BlockTypeFunction BlockTypeFunction
) )
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
type UsageMapHook func(usageMap map[string]any)
// StreamingProcessor 流式响应处理器 // StreamingProcessor 流式响应处理器
type StreamingProcessor struct { type StreamingProcessor struct {
blockType BlockType blockType BlockType
@ -30,6 +33,7 @@ type StreamingProcessor struct {
originalModel string originalModel string
webSearchQueries []string webSearchQueries []string
groundingChunks []GeminiGroundingChunk groundingChunks []GeminiGroundingChunk
usageMapHook UsageMapHook
// 累计 usage // 累计 usage
inputTokens int inputTokens int
@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
} }
} }
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
p.usageMapHook = fn
}
func usageToMap(u ClaudeUsage) map[string]any {
m := map[string]any{
"input_tokens": u.InputTokens,
"output_tokens": u.OutputTokens,
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.ImageOutputTokens > 0 {
m["image_output_tokens"] = u.ImageOutputTokens
}
return m
}
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件 // ProcessLine 处理 SSE 行,返回 Claude SSE 事件
func (p *StreamingProcessor) ProcessLine(line string) []byte { func (p *StreamingProcessor) ProcessLine(line string) []byte {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
responseID = "msg_" + generateRandomID() responseID = "msg_" + generateRandomID()
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
message := map[string]any{ message := map[string]any{
"id": responseID, "id": responseID,
"type": "message", "type": "message",
@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
"model": p.originalModel, "model": p.originalModel,
"stop_reason": nil, "stop_reason": nil,
"stop_sequence": nil, "stop_sequence": nil,
"usage": usage, "usage": usageValue,
} }
event := map[string]any{ event := map[string]any{
@ -492,13 +525,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
ImageOutputTokens: p.imageOutputTokens, ImageOutputTokens: p.imageOutputTokens,
} }
var usageValue any = usage
if p.usageMapHook != nil {
usageMap := usageToMap(usage)
p.usageMapHook(usageMap)
usageValue = usageMap
}
deltaEvent := map[string]any{ deltaEvent := map[string]any{
"type": "message_delta", "type": "message_delta",
"delta": map[string]any{ "delta": map[string]any{
"stop_reason": stopReason, "stop_reason": stopReason,
"stop_sequence": nil, "stop_sequence": nil,
}, },
"usage": usage, "usage": usageValue,
} }
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) _, _ = result.Write(p.formatSSE("message_delta", deltaEvent))

View File

@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
} }
out := &ResponsesRequest{ out := &ResponsesRequest{
Model: req.Model, Model: req.Model,
Input: inputJSON, Instructions: req.Instructions,
Temperature: req.Temperature, Input: inputJSON,
TopP: req.TopP, Temperature: req.Temperature,
Stream: true, // upstream always streams TopP: req.TopP,
Include: []string{"reasoning.encrypted_content"}, Stream: true, // upstream always streams
ServiceTier: req.ServiceTier, Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
} }
storeFalse := false storeFalse := false

View File

@ -152,6 +152,7 @@ type AnthropicDelta struct {
// ResponsesRequest is the request body for POST /v1/responses. // ResponsesRequest is the request body for POST /v1/responses.
type ResponsesRequest struct { type ResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Instructions string `json:"instructions,omitempty"`
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
MaxOutputTokens *int `json:"max_output_tokens,omitempty"` MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`
@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
type ChatCompletionsRequest struct { type ChatCompletionsRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []ChatMessage `json:"messages"` Messages []ChatMessage `json:"messages"`
Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
MaxTokens *int `json:"max_tokens,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature *float64 `json:"temperature,omitempty"`

View File

@ -10,7 +10,13 @@ import (
) )
func TestInit_DualOutput(t *testing.T) { func TestInit_DualOutput(t *testing.T) {
tmpDir := t.TempDir() // Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
// when lumberjack holds file handles on Windows.
tmpDir, err := os.MkdirTemp("", "logger-test-*")
if err != nil {
t.Fatalf("create temp dir: %v", err)
}
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
logPath := filepath.Join(tmpDir, "logs", "sub2api.log") logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
origStdout := os.Stdout origStdout := os.Stdout
@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
L().Info("dual-output-info") L().Info("dual-output-info")
L().Warn("dual-output-warn") L().Warn("dual-output-warn")
Sync()
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
// The log data is already in the pipe buffer; closing writers is sufficient.
_ = stdoutW.Close() _ = stdoutW.Close()
_ = stderrW.Close() _ = stderrW.Close()
@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
} }
L().Info("caller-check") L().Info("caller-check")
Sync() // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
os.Stdout = origStdout
os.Stderr = origStderr
_ = stdoutW.Close() _ = stdoutW.Close()
logBytes, _ := io.ReadAll(stdoutR) logBytes, _ := io.ReadAll(stdoutR)

View File

@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
log.Printf("service started") log.Printf("service started")
log.Printf("Warning: queue full") log.Printf("Warning: queue full")
log.Printf("Forward request failed: timeout") log.Printf("Forward request failed: timeout")
Sync() // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close() _ = stdoutW.Close()
_ = stderrW.Close() _ = stderrW.Close()
@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
LegacyPrintf("service.test", "request started") LegacyPrintf("service.test", "request started")
LegacyPrintf("service.test", "Warning: queue full") LegacyPrintf("service.test", "Warning: queue full")
LegacyPrintf("service.test", "forward failed: timeout") LegacyPrintf("service.test", "forward failed: timeout")
Sync() // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
_ = stdoutW.Close() _ = stdoutW.Close()
_ = stderrW.Close() _ = stderrW.Close()

View File

@ -0,0 +1,106 @@
package websearch
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
)
const (
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
braveMaxCount = 20
braveProviderName = "brave"
)
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
// BraveProvider implements web search via the Brave Search API.
type BraveProvider struct {
apiKey string
httpClient *http.Client
}
// NewBraveProvider creates a Brave Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
if httpClient == nil {
httpClient = http.DefaultClient
}
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
}
func (b *BraveProvider) Name() string { return braveProviderName }
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
count := req.MaxResults
if count <= 0 {
count = defaultMaxResults
}
if count > braveMaxCount {
count = braveMaxCount
}
u := *braveSearchURL // copy the pre-parsed URL
q := u.Query()
q.Set("q", req.Query)
q.Set("count", strconv.Itoa(count))
u.RawQuery = q.Encode()
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("brave: build request: %w", err)
}
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
httpReq.Header.Set("Accept", "application/json")
resp, err := b.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("brave: request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("brave: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
}
var raw braveResponse
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("brave: decode response: %w", err)
}
results := make([]SearchResult, 0, len(raw.Web.Results))
for _, r := range raw.Web.Results {
results = append(results, SearchResult{
URL: r.URL,
Title: r.Title,
Snippet: r.Description,
PageAge: r.Age,
})
}
return &SearchResponse{Results: results, Query: req.Query}, nil
}
// braveResponse is the minimal structure of the Brave Search API response.
type braveResponse struct {
Web struct {
Results []braveResult `json:"results"`
} `json:"web"`
}
type braveResult struct {
URL string `json:"url"`
Title string `json:"title"`
Description string `json:"description"`
Age string `json:"age"`
}

View File

@ -0,0 +1,119 @@
package websearch
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestBraveProvider_Name(t *testing.T) {
p := NewBraveProvider("key", nil)
require.Equal(t, "brave", p.Name())
}
func TestBraveProvider_Search_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
require.Equal(t, "application/json", r.Header.Get("Accept"))
require.Equal(t, "golang", r.URL.Query().Get("q"))
require.Equal(t, "3", r.URL.Query().Get("count"))
resp := braveResponse{}
resp.Web.Results = []braveResult{
{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("test-key", srv.Client())
// Override the endpoint for testing
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
require.NoError(t, err)
require.Len(t, resp.Results, 3)
require.Equal(t, "https://go.dev", resp.Results[0].URL)
require.Equal(t, "Go lang", resp.Results[0].Snippet)
require.Equal(t, "1 day", resp.Results[0].PageAge)
}
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
var receivedCount string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedCount = r.URL.Query().Get("count")
resp := braveResponse{}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
require.Equal(t, "5", receivedCount)
}
func TestBraveProvider_Search_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(429)
_, _ = w.Write([]byte("rate limited"))
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "brave: status 429")
}
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("not json"))
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "brave: decode response")
}
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
p := NewBraveProvider("key", srv.Client())
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Empty(t, resp.Results)
}

View File

@ -0,0 +1,14 @@
package websearch
const (
maxResponseSize = 1 << 20 // 1 MB
errorBodyTruncLen = 200
)
// truncateBody returns a truncated string of body for error messages.
func truncateBody(body []byte) string {
if len(body) <= errorBodyTruncLen {
return string(body)
}
return string(body[:errorBodyTruncLen]) + "...(truncated)"
}

View File

@ -0,0 +1,25 @@
package websearch
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestTruncateBody_Short(t *testing.T) {
body := []byte("short body")
require.Equal(t, "short body", truncateBody(body))
}
func TestTruncateBody_Long(t *testing.T) {
body := []byte(strings.Repeat("x", 500))
result := truncateBody(body)
require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
require.True(t, strings.HasSuffix(result, "...(truncated)"))
}
func TestTruncateBody_ExactBoundary(t *testing.T) {
body := []byte(strings.Repeat("x", errorBodyTruncLen))
require.Equal(t, string(body), truncateBody(body))
}

View File

@ -0,0 +1,528 @@
package websearch
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"math/rand"
"net"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/redis/go-redis/v9"
)
// ProviderConfig holds the configuration for a single search provider.
type ProviderConfig struct {
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
APIKey string `json:"api_key"` // secret
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
}
// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
type Manager struct {
configs []ProviderConfig
redis *redis.Client
clientMu sync.Mutex
clientCache map[string]*http.Client
}
// Timeout constants for proxy and search operations.
const (
proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
searchDataTimeout = 60 * time.Second // response data transfer timeout
searchRequestTimeout = searchDataTimeout + proxyDialTimeout
quotaKeyPrefix = "websearch:quota:"
proxyUnavailableKey = "websearch:proxy_unavailable:%d"
proxyUnavailableTTL = 5 * time.Minute
quotaTTLBuffer = 24 * time.Hour
defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
maxCachedClients = 100
)
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
// Callers may use this to trigger account switching instead of direct fallback.
var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
var quotaIncrScript = redis.NewScript(`
local val = redis.call('INCR', KEYS[1])
if val == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
else
local ttl = redis.call('TTL', KEYS[1])
if ttl == -1 then
redis.call('EXPIRE', KEYS[1], ARGV[1])
end
end
return val
`)
// NewManager creates a Manager with the given provider configs and Redis client.
// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
copied := make([]ProviderConfig, len(configs))
copy(copied, configs)
return &Manager{
configs: copied,
redis: redisClient,
clientCache: make(map[string]*http.Client),
}
}
// SearchWithBestProvider selects a provider using quota-weighted load balancing,
// reserves quota, executes the search, and rolls back quota on failure.
// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query")
}
candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
if len(candidates) == 0 {
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
}
selected := m.selectByQuotaWeight(ctx, candidates)
for _, cfg := range selected {
allowed, incremented := m.tryReserveQuota(ctx, cfg)
if !allowed {
continue
}
resp, err := m.executeSearch(ctx, cfg, req)
if err != nil {
if incremented {
m.rollbackQuota(ctx, cfg)
}
if isProxyError(err) {
m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
if req.ProxyURL != "" {
// Account-level proxy is shared by all providers — no point
// trying others with the same broken proxy; signal account switch.
slog.Warn("websearch: account proxy error, aborting failover",
"provider", cfg.Type, "error", err)
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
}
// Provider-specific proxy failed — try the next provider which
// may use a different (or no) proxy.
slog.Warn("websearch: provider proxy error, trying next provider",
"provider", cfg.Type, "error", err)
continue
}
slog.Warn("websearch: provider search failed",
"provider", cfg.Type, "error", err)
continue
}
return resp, cfg.Type, nil
}
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
}
// filterAvailableProviders returns providers that have API keys, are not expired,
// and whose proxies are not marked unavailable.
func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
var out []ProviderConfig
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) {
continue
}
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
slog.Debug("websearch: proxy marked unavailable, skipping",
"provider", cfg.Type, "proxy_id", proxyID)
continue
}
out = append(out, cfg)
}
return out
}
// weighted is a provider candidate with computed quota weight.
type weighted struct {
cfg ProviderConfig
weight int64
}
// selectByQuotaWeight orders candidates by remaining quota weight.
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
// Among providers with quota, higher remaining quota = higher priority.
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
items := m.computeWeights(ctx, candidates)
withQuota, withoutQuota := partitionByQuota(items)
sortByStableRandomWeight(withQuota)
return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
}
func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
items := make([]weighted, 0, len(candidates))
for _, cfg := range candidates {
w := int64(0)
if cfg.QuotaLimit > 0 {
used, _ := m.GetUsage(ctx, cfg.Type)
if remaining := cfg.QuotaLimit - used; remaining > 0 {
w = remaining
}
}
items = append(items, weighted{cfg: cfg, weight: w})
}
return items
}
func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
for _, item := range items {
if item.weight > 0 {
withQuota = append(withQuota, item)
} else {
withoutQuota = append(withoutQuota, item)
}
}
return
}
// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
// ensuring deterministic sort behavior (transitivity) within a single call.
func sortByStableRandomWeight(items []weighted) {
if len(items) <= 1 {
return
}
type entry struct {
item weighted
factor float64
}
entries := make([]entry, len(items))
for i, item := range items {
entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].factor > entries[j].factor
})
for i, e := range entries {
items[i] = e.item
}
}
func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
result := make([]ProviderConfig, 0, capacity)
for _, item := range withQuota {
result = append(result, item.cfg)
}
for _, item := range withoutQuota {
result = append(result, item.cfg)
}
return result
}
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
if cfg.APIKey == "" {
return false
}
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
slog.Info("websearch: provider expired, skipping",
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
return false
}
return true
}
// --- Proxy availability tracking ---
// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
proxyID := resolveProxyID(cfg, accountProxyURL)
if proxyID <= 0 || m.redis == nil {
return
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
slog.Warn("websearch: failed to mark proxy unavailable",
"proxy_id", proxyID, "error", err)
}
}
// isProxyAvailable checks whether a proxy is currently marked as unavailable.
func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
if m.redis == nil || proxyID <= 0 {
return true
}
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
val, err := m.redis.Get(ctx, key).Result()
if err != nil {
return true // Redis error → assume available
}
return val == ""
}
// resolveProxyID determines the effective proxy ID for a provider+account combination.
func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
if accountProxyURL != "" {
return 0 // account proxy has no ID in provider config
}
return cfg.ProxyID
}
// isProxyError checks whether the error is likely caused by proxy or network connectivity
// (as opposed to an API-level error from the search provider).
func isProxyError(err error) bool {
if err == nil {
return false
}
// Network-level errors (timeout, connection refused, DNS failure)
var netErr net.Error
if errors.As(err, &netErr) {
return true
}
var opErr *net.OpError
if errors.As(err, &opErr) {
return true
}
// TLS handshake failures (often caused by proxy intercepting/blocking)
var tlsErr *tls.RecordHeaderError
if errors.As(err, &tlsErr) {
return true
}
// String-based detection for wrapped errors
msg := strings.ToLower(err.Error())
return strings.Contains(msg, "proxy") ||
strings.Contains(msg, "socks") ||
strings.Contains(msg, "connection refused") ||
strings.Contains(msg, "no such host") ||
strings.Contains(msg, "i/o timeout") ||
strings.Contains(msg, "tls handshake") ||
strings.Contains(msg, "certificate")
}
// --- Quota management ---
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
if cfg.QuotaLimit <= 0 {
return true, false
}
if m.redis == nil {
slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
return true, false
}
key := quotaRedisKey(cfg.Type)
ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
if err != nil {
slog.Warn("websearch: quota Lua INCR failed, allowing request",
"provider", cfg.Type, "error", err)
return true, false
}
if newVal > cfg.QuotaLimit {
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
slog.Warn("websearch: quota over-limit DECR failed",
"provider", cfg.Type, "error", decrErr)
}
slog.Info("websearch: provider quota exhausted",
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
return false, false
}
return true, true
}
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
if cfg.QuotaLimit <= 0 || m.redis == nil {
return
}
key := quotaRedisKey(cfg.Type)
if err := m.redis.Decr(ctx, key).Err(); err != nil {
slog.Warn("websearch: quota rollback DECR failed",
"provider", cfg.Type, "error", err)
}
}
// --- Search execution ---
// TestSearch executes a search using the first available provider without reserving quota.
// Intended for admin test functionality only.
func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
if strings.TrimSpace(req.Query) == "" {
return nil, "", fmt.Errorf("websearch: empty search query")
}
for _, cfg := range m.configs {
if !m.isProviderAvailable(cfg) {
continue
}
resp, err := m.executeSearch(ctx, cfg, req)
if err != nil {
continue
}
return resp, cfg.Type, nil
}
return nil, "", fmt.Errorf("websearch: no available provider")
}
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
proxyURL := cfg.ProxyURL
if req.ProxyURL != "" {
proxyURL = req.ProxyURL
}
client, err := m.getOrCreateHTTPClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("websearch: %w", err)
}
provider := m.buildProvider(cfg, client)
return provider.Search(ctx, req)
}
// --- HTTP client cache ---
func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
m.clientMu.Lock()
defer m.clientMu.Unlock()
if c, ok := m.clientCache[proxyURL]; ok {
return c, nil
}
if len(m.clientCache) >= maxCachedClients {
m.clientCache = make(map[string]*http.Client)
}
c, err := newHTTPClient(proxyURL)
if err != nil {
return nil, err
}
m.clientCache[proxyURL] = c
return c, nil
}
// newHTTPClient creates an HTTP client with proper timeout settings.
// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
// (HTTP/HTTPS/SOCKS5/SOCKS5H).
// Returns error if proxyURL is invalid — never falls back to direct connection.
func newHTTPClient(proxyURL string) (*http.Client, error) {
transport := &http.Transport{
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
TLSHandshakeTimeout: proxyTLSTimeout,
ResponseHeaderTimeout: searchDataTimeout,
}
if proxyURL != "" {
parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
}
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, fmt.Errorf("configure proxy: %w", err)
}
}
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
}
// GetUsage returns the current usage count for the given provider.
func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
if m.redis == nil {
return 0, nil
}
key := quotaRedisKey(providerType)
val, err := m.redis.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
}
return val, err
}
// GetAllUsage returns usage for every configured provider.
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
result := make(map[string]int64, len(m.configs))
for _, cfg := range m.configs {
used, _ := m.GetUsage(ctx, cfg.Type)
result[cfg.Type] = used
}
return result
}
// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0.
func (m *Manager) ResetUsage(ctx context.Context, providerType string) error {
if m.redis == nil {
return nil
}
key := quotaRedisKey(providerType)
return m.redis.Del(ctx, key).Err()
}
// --- Provider factory ---
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
switch cfg.Type {
case braveProviderName:
return NewBraveProvider(cfg.APIKey, client)
case tavilyProviderName:
return NewTavilyProvider(cfg.APIKey, client)
default:
slog.Warn("websearch: unknown provider type, falling back to brave",
"type", cfg.Type)
return NewBraveProvider(cfg.APIKey, client)
}
}
// --- Redis key helpers ---
func quotaRedisKey(providerType string) string {
return quotaKeyPrefix + providerType
}
// quotaTTLFromSubscription calculates the TTL for the quota counter based on
// the provider's subscription start date. Quota resets monthly from that date.
// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
if subscribedAt == nil || *subscribedAt == 0 {
return defaultQuotaTTL
}
next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
ttl := time.Until(next) + quotaTTLBuffer
if ttl <= quotaTTLBuffer {
// Already past the reset — next cycle
ttl = defaultQuotaTTL
}
return ttl
}
// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
func nextMonthlyReset(subscribedAt time.Time) time.Time {
now := time.Now().UTC()
if subscribedAt.IsZero() {
return now.AddDate(0, 1, 0)
}
months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
if months < 0 {
months = 0
}
candidate := addMonthsClamped(subscribedAt, months)
if candidate.After(now) {
return candidate
}
return addMonthsClamped(subscribedAt, months+1)
}
// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
func addMonthsClamped(t time.Time, months int) time.Time {
y, m, d := t.Date()
targetMonth := time.Month(int(m) + months)
targetYear := y + int(targetMonth-1)/12
targetMonth = (targetMonth-1)%12 + 1
// Last day of the target month
lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day()
if d > lastDay {
d = lastDay
}
return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC)
}

View File

@ -0,0 +1,323 @@
package websearch
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestNewManager_PreservesOrder(t *testing.T) {
configs := []ProviderConfig{
{Type: "brave", APIKey: "k3"},
{Type: "tavily", APIKey: "k1"},
}
m := NewManager(configs, nil)
require.Equal(t, "brave", m.configs[0].Type)
require.Equal(t, "tavily", m.configs[1].Type)
}
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
require.ErrorContains(t, err, "empty search query")
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
require.ErrorContains(t, err, "empty search query")
}
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "no available provider")
}
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
past := time.Now().Add(-1 * time.Hour).Unix()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k", ExpiresAt: &past},
}, nil)
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.ErrorContains(t, err, "no available provider")
}
func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srvBrave.Close()
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k1"},
{Type: "tavily", APIKey: "k2"},
}, nil)
m.clientCache[srvBrave.URL] = srvBrave.Client()
m.clientCache[""] = srvBrave.Client()
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Equal(t, "brave", providerName)
require.Len(t, resp.Results, 1)
require.Equal(t, "from brave", resp.Results[0].Snippet)
}
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
resp := braveResponse{}
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
_ = json.NewEncoder(w).Encode(resp)
}))
defer srv.Close()
origURL := *braveSearchURL
u, _ := http.NewRequest("GET", srv.URL, nil)
*braveSearchURL = *u.URL
defer func() { *braveSearchURL = origURL }()
m := NewManager([]ProviderConfig{
{Type: "brave", APIKey: "k", QuotaLimit: 100},
}, nil)
m.clientCache[""] = srv.Client()
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
require.NoError(t, err)
require.Len(t, resp.Results, 1)
}
func TestManager_GetUsage_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
used, err := m.GetUsage(context.Background(), "brave")
require.NoError(t, err)
require.Equal(t, int64(0), used)
}
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
m := NewManager([]ProviderConfig{
{Type: "brave"},
}, nil)
usage := m.GetAllUsage(context.Background())
require.Equal(t, int64(0), usage["brave"])
}
// --- Quota TTL from subscription ---
func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
ttl := quotaTTLFromSubscription(nil)
require.Equal(t, defaultQuotaTTL, ttl)
}
func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
zero := int64(0)
ttl := quotaTTLFromSubscription(&zero)
require.Equal(t, defaultQuotaTTL, ttl)
}
func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
// Subscribed 10 days ago — next reset in ~20 days
sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
ttl := quotaTTLFromSubscription(&sub)
require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
}
func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
// Subscribed on the 10th of this month (always valid day)
now := time.Now().UTC()
sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
next := nextMonthlyReset(sub)
require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
require.True(t, next.Before(now.AddDate(0, 1, 1)))
}
func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
// Subscribed 6 months ago on the 1st
sub := time.Now().UTC().AddDate(0, -6, 0)
sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
next := nextMonthlyReset(sub)
require.True(t, next.After(time.Now().UTC()))
// Should be within the next 31 days
require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
}
func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
sub := time.Now().UTC().AddDate(0, 0, 5)
next := nextMonthlyReset(sub)
require.True(t, next.After(time.Now().UTC()))
}
func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
}
func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
}
func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(4), next.Month())
require.Equal(t, 30, next.Day()) // Apr has 30 days
}
func TestAddMonthsClamped_NormalDay(t *testing.T) {
sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
next := addMonthsClamped(sub, 1)
require.Equal(t, time.Month(2), next.Month())
require.Equal(t, 15, next.Day()) // no clamping needed
}
// --- Redis key ---
func TestQuotaRedisKey_Format(t *testing.T) {
key := quotaRedisKey("brave")
require.Equal(t, "websearch:quota:brave", key)
}
// --- isProviderAvailable ---
func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
m := NewManager(nil, nil)
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
}
func TestIsProviderAvailable_Expired(t *testing.T) {
m := NewManager(nil, nil)
past := time.Now().Add(-1 * time.Hour).Unix()
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
}
func TestIsProviderAvailable_Valid(t *testing.T) {
m := NewManager(nil, nil)
future := time.Now().Add(1 * time.Hour).Unix()
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
}
// --- resolveProxyID ---
func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
cfg := ProviderConfig{ProxyID: 42}
require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
require.Equal(t, int64(42), resolveProxyID(cfg, ""))
}
// --- isProxyError ---
func TestIsProxyError_Nil(t *testing.T) {
require.False(t, isProxyError(nil))
}
func TestIsProxyError_ConnectionRefused(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
}
func TestIsProxyError_Timeout(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
}
func TestIsProxyError_SOCKS(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
}
func TestIsProxyError_TLSHandshake(t *testing.T) {
require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
}
func TestIsProxyError_APIError_NotProxy(t *testing.T) {
require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
}
// --- isProxyAvailable (nil Redis) ---
func TestIsProxyAvailable_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
require.True(t, m.isProxyAvailable(context.Background(), 42))
}
func TestIsProxyAvailable_ZeroID(t *testing.T) {
m := NewManager(nil, nil)
require.True(t, m.isProxyAvailable(context.Background(), 0))
}
// --- selectByQuotaWeight ---
func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
m := NewManager(nil, nil)
candidates := []ProviderConfig{
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
{Type: "tavily", APIKey: "k2", QuotaLimit: 100},
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
require.Equal(t, "tavily", result[0].Type)
require.Equal(t, "brave", result[1].Type)
}
func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
m := NewManager(nil, nil)
candidates := []ProviderConfig{
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
{Type: "tavily", APIKey: "k2", QuotaLimit: 0},
}
result := m.selectByQuotaWeight(context.Background(), candidates)
require.Len(t, result, 2)
}
func TestSelectByQuotaWeight_Empty(t *testing.T) {
m := NewManager(nil, nil)
result := m.selectByQuotaWeight(context.Background(), nil)
require.Empty(t, result)
}
// --- newHTTPClient ---
func TestNewHTTPClient_NoProxy(t *testing.T) {
c, err := newHTTPClient("")
require.NoError(t, err)
require.NotNil(t, c)
}
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
_, err := newHTTPClient("://bad-url")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid proxy URL")
}
func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
c, err := newHTTPClient("http://proxy.example.com:8080")
require.NoError(t, err)
require.NotNil(t, c)
}
func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
c, err := newHTTPClient("socks5://proxy.example.com:1080")
require.NoError(t, err)
require.NotNil(t, c)
}
// --- ResetUsage ---
func TestManager_ResetUsage_NilRedis(t *testing.T) {
m := NewManager(nil, nil)
err := m.ResetUsage(context.Background(), "brave")
require.NoError(t, err)
}

View File

@ -0,0 +1,11 @@
package websearch
import "context"
// Provider is the interface every search backend must implement.
type Provider interface {
// Name returns the provider identifier ("brave" or "tavily").
Name() string
// Search executes a web search and returns results.
Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
}

View File

@ -0,0 +1,107 @@
package websearch
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
)
const (
tavilySearchEndpoint = "https://api.tavily.com/search"
tavilyProviderName = "tavily"
tavilySearchDepthBasic = "basic"
)
// TavilyProvider implements web search via the Tavily Search API.
type TavilyProvider struct {
apiKey string
httpClient *http.Client
}
// NewTavilyProvider creates a Tavily Search provider.
// The caller is responsible for configuring the http.Client with proxy/timeouts.
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
if httpClient == nil {
httpClient = http.DefaultClient
}
return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
}
func (t *TavilyProvider) Name() string { return tavilyProviderName }
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
maxResults := req.MaxResults
if maxResults <= 0 {
maxResults = defaultMaxResults
}
payload := tavilyRequest{
APIKey: t.apiKey,
Query: req.Query,
MaxResults: maxResults,
SearchDepth: tavilySearchDepthBasic,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("tavily: encode request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("tavily: build request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := t.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("tavily: request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, fmt.Errorf("tavily: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
}
var raw tavilyResponse
if err := json.Unmarshal(body, &raw); err != nil {
return nil, fmt.Errorf("tavily: decode response: %w", err)
}
results := make([]SearchResult, 0, len(raw.Results))
for _, r := range raw.Results {
results = append(results, SearchResult{
URL: r.URL,
Title: r.Title,
Snippet: r.Content,
})
}
return &SearchResponse{Results: results, Query: req.Query}, nil
}
type tavilyRequest struct {
APIKey string `json:"api_key"`
Query string `json:"query"`
MaxResults int `json:"max_results"`
SearchDepth string `json:"search_depth"`
}
type tavilyResponse struct {
Results []tavilyResult `json:"results"`
}
type tavilyResult struct {
URL string `json:"url"`
Title string `json:"title"`
Content string `json:"content"`
Score float64 `json:"score"`
}

View File

@ -0,0 +1,63 @@
package websearch
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestTavilyProvider_Name(t *testing.T) {
p := NewTavilyProvider("key", nil)
require.Equal(t, "tavily", p.Name())
}
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
// Verify tavilyRequest struct fields map correctly
req := tavilyRequest{
APIKey: "test-key",
Query: "golang",
MaxResults: 3,
SearchDepth: tavilySearchDepthBasic,
}
data, err := json.Marshal(req)
require.NoError(t, err)
var parsed map[string]any
require.NoError(t, json.Unmarshal(data, &parsed))
require.Equal(t, "test-key", parsed["api_key"])
require.Equal(t, "golang", parsed["query"])
require.Equal(t, float64(3), parsed["max_results"])
require.Equal(t, "basic", parsed["search_depth"])
}
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
var resp tavilyResponse
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
require.Len(t, resp.Results, 1)
require.Equal(t, "https://go.dev", resp.Results[0].URL)
require.Equal(t, "Go programming language", resp.Results[0].Content)
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
// Verify mapping to SearchResult
results := make([]SearchResult, 0, len(resp.Results))
for _, r := range resp.Results {
results = append(results, SearchResult{
URL: r.URL, Title: r.Title, Snippet: r.Content,
})
}
require.Equal(t, "Go programming language", results[0].Snippet)
require.Equal(t, "", results[0].PageAge)
}
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
var resp tavilyResponse
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
require.Empty(t, resp.Results)
}
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
var resp tavilyResponse
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
}

View File

@ -0,0 +1,30 @@
package websearch
// SearchResult represents a single web search result.
type SearchResult struct {
URL string `json:"url"`
Title string `json:"title"`
Snippet string `json:"snippet"`
PageAge string `json:"page_age,omitempty"`
}
// SearchRequest describes a web search to perform.
type SearchRequest struct {
Query string
MaxResults int // defaults to defaultMaxResults if <= 0
ProxyURL string // optional HTTP proxy URL
}
// SearchResponse holds the results of a web search.
type SearchResponse struct {
Results []SearchResult
Query string // the query that was actually executed
}
const defaultMaxResults = 5
// Provider type identifiers.
const (
ProviderTypeBrave = "brave"
ProviderTypeTavily = "tavily"
)

View File

@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
WithUser(func(q *dbent.UserQuery) { WithUser(func(q *dbent.UserQuery) {
q.Select( q.Select(
user.FieldID, user.FieldID,
user.FieldEmail,
user.FieldUsername,
user.FieldStatus, user.FieldStatus,
user.FieldRole, user.FieldRole,
user.FieldBalance, user.FieldBalance,
user.FieldConcurrency, user.FieldConcurrency,
user.FieldBalanceNotifyEnabled,
user.FieldBalanceNotifyThresholdType,
user.FieldBalanceNotifyThreshold,
user.FieldBalanceNotifyExtraEmails,
user.FieldTotalRecharged,
) )
}). }).
WithGroup(func(q *dbent.GroupQuery) { WithGroup(func(q *dbent.GroupQuery) {
@ -639,22 +646,31 @@ func userEntityToService(u *dbent.User) *service.User {
if u == nil { if u == nil {
return nil return nil
} }
return &service.User{ out := &service.User{
ID: u.ID, ID: u.ID,
Email: u.Email, Email: u.Email,
Username: u.Username, Username: u.Username,
Notes: u.Notes, Notes: u.Notes,
PasswordHash: u.PasswordHash, PasswordHash: u.PasswordHash,
Role: u.Role, Role: u.Role,
Balance: u.Balance, Balance: u.Balance,
Concurrency: u.Concurrency, Concurrency: u.Concurrency,
Status: u.Status, Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted, TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled, TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt, TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt, BalanceNotifyEnabled: u.BalanceNotifyEnabled,
UpdatedAt: u.UpdatedAt, BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
TotalRecharged: u.TotalRecharged,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
} }
// Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
}
return out
} }
func groupEntityToService(g *dbent.Group) *service.Group { func groupEntityToService(g *dbent.Group) *service.Group {

View File

@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
if err != nil { if err != nil {
return err return err
} }
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
err = tx.QueryRowContext(ctx, err = tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, created_at, updated_at`, RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil { if err != nil {
if isUniqueViolation(err) { if isUniqueViolation(err) {
@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
} }
} }
// 设置账号统计定价规则
if len(channel.AccountStatsPricingRules) > 0 {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil return nil
}) })
} }
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{} ch := &service.Channel{}
var modelMappingJSON []byte var modelMappingJSON, featuresConfigJSON []byte
err := r.db.QueryRowContext(ctx, err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
FROM channels WHERE id = $1`, id, FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound return nil, service.ErrChannelNotFound
} }
@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
return nil, fmt.Errorf("get channel: %w", err) return nil, fmt.Errorf("get channel: %w", err)
} }
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
groupIDs, err := r.GetGroupIDs(ctx, id) groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil { if err != nil {
@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
} }
ch.ModelPricing = pricing ch.ModelPricing = pricing
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
if err != nil {
return nil, err
}
ch.AccountStatsPricingRules = statsPricingRules
return ch, nil return ch, nil
} }
@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
if err != nil { if err != nil {
return err return err
} }
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
if err != nil {
return err
}
result, err := tx.ExecContext(ctx, result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
WHERE id = $7`, WHERE id = $10`,
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
) )
if err != nil { if err != nil {
if isUniqueViolation(err) { if isUniqueViolation(err) {
@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
} }
} }
// 更新账号统计定价规则
if channel.AccountStatsPricingRules != nil {
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
return err
}
}
return nil return nil
}) })
} }
@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
// 查询 channel 列表 // 查询 channel 列表
dataQuery := fmt.Sprintf( dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
whereClause, channelListOrderBy(params), argIdx, argIdx+1, whereClause, channelListOrderBy(params), argIdx, argIdx+1,
) )
@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
var channelIDs []int64 var channelIDs []int64
for rows.Next() { for rows.Next() {
var ch service.Channel var ch service.Channel
var modelMappingJSON []byte var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err) return nil, nil, fmt.Errorf("scan channel: %w", err)
} }
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch) channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID) channelIDs = append(channelIDs, ch.ID)
} }
@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels { for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
} }
} }
@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx, rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("query all channels: %w", err) return nil, fmt.Errorf("query all channels: %w", err)
@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
var channelIDs []int64 var channelIDs []int64
for rows.Next() { for rows.Next() {
var ch service.Channel var ch service.Channel
var modelMappingJSON []byte var modelMappingJSON, featuresConfigJSON []byte
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err) return nil, fmt.Errorf("scan channel: %w", err)
} }
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
channels = append(channels, ch) channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID) channelIDs = append(channelIDs, ch.ID)
} }
@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
return nil, err return nil, err
} }
// 批量加载账号统计定价规则
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels { for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID]
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
} }
return channels, nil return channels, nil
@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
return m return m
} }
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
if len(m) == 0 {
return []byte("{}"), nil
}
data, err := json.Marshal(m)
if err != nil {
return nil, fmt.Errorf("marshal features_config: %w", err)
}
return data, nil
}
func unmarshalFeaturesConfig(data []byte) map[string]any {
if len(data) == 0 {
return nil
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil
}
return m
}
// GetGroupPlatforms 批量查询分组 ID 对应的平台 // GetGroupPlatforms 批量查询分组 ID 对应的平台
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
if len(groupIDs) == 0 { if len(groupIDs) == 0 {

View File

@ -0,0 +1,244 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 账号统计定价规则 ---
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
// 1. 查询规则
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
}
defer func() { _ = rows.Close() }()
var allRules []service.AccountStatsPricingRule
var ruleIDs []int64
for rows.Next() {
var rule service.AccountStatsPricingRule
if err := rows.Scan(
&rule.ID, &rule.ChannelID, &rule.Name,
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
}
ruleIDs = append(ruleIDs, rule.ID)
allRules = append(allRules, rule)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
}
// 2. 批量加载规则的模型定价
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
if err != nil {
return nil, err
}
// 3. 按 channelID 分组并关联定价
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
for i := range allRules {
allRules[i].Pricing = pricingMap[allRules[i].ID]
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
}
return result, nil
}
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
if len(ruleIDs) == 0 {
return make(map[int64][]service.ChannelModelPricing), nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
pq.Array(ruleIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
}
defer func() { _ = rows.Close() }()
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
for rows.Next() {
var p service.ChannelModelPricing
var ruleID int64
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingMap[ruleID] = append(pricingMap[ruleID], p)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
}
// Load intervals for all pricing entries.
var allPricingIDs []int64
for _, pricings := range pricingMap {
for _, p := range pricings {
allPricingIDs = append(allPricingIDs, p.ID)
}
}
if len(allPricingIDs) > 0 {
intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for ruleID, pricings := range pricingMap {
for i := range pricings {
pricings[i].Intervals = intervalsMap[pricings[i].ID]
}
pricingMap[ruleID] = pricings
}
}
return pricingMap, nil
}
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
if err != nil {
return nil, err
}
return result[channelID], nil
}
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
// CASCADE 会自动删除关联的 model_pricing
if _, err := tx.ExecContext(ctx,
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
); err != nil {
return fmt.Errorf("delete old account stats pricing rules: %w", err)
}
for i := range rules {
rules[i].ChannelID = channelID
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
}
return nil
}
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats pricing rule: %w", err)
}
for j := range rule.Pricing {
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
return err
}
}
return nil
}
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
platform := pricing.Platform
err = tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
ruleID, platform, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.PerRequestPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert account stats model pricing: %w", err)
}
// Persist intervals (mirrors channel_pricing_intervals logic).
for i := range pricing.Intervals {
iv := &pricing.Intervals[i]
iv.PricingID = pricing.ID
if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
return err
}
}
return nil
}
// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
return tx.QueryRowContext(ctx,
`INSERT INTO channel_account_stats_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
if len(pricingIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_account_stats_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
}
defer func() { _ = rows.Close() }()
result := make(map[int64][]service.PricingInterval)
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
}
result[iv.PricingID] = append(result[iv.PricingID], iv)
}
return result, rows.Err()
}

View File

@ -3,6 +3,8 @@ package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"strings"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@ -11,23 +13,33 @@ import (
const ( const (
verifyCodeKeyPrefix = "verify_code:" verifyCodeKeyPrefix = "verify_code:"
notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:" passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:" passwordResetSentAtKeyPrefix = "password_reset_sent:"
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
) )
// verifyCodeKey generates the Redis key for email verification code. // verifyCodeKey generates the Redis key for email verification code.
// Email is lowercased for case-insensitive consistency.
func verifyCodeKey(email string) string { func verifyCodeKey(email string) string {
return verifyCodeKeyPrefix + email return verifyCodeKeyPrefix + strings.ToLower(email)
}
// notifyVerifyKey generates the Redis key for notify email verification code.
// Email is lowercased to prevent case-sensitive key mismatch (the business layer
// uses strings.EqualFold for comparison).
func notifyVerifyKey(email string) string {
return notifyVerifyKeyPrefix + strings.ToLower(email)
} }
// passwordResetKey generates the Redis key for password reset token. // passwordResetKey generates the Redis key for password reset token.
func passwordResetKey(email string) string { func passwordResetKey(email string) string {
return passwordResetKeyPrefix + email return passwordResetKeyPrefix + strings.ToLower(email)
} }
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. // passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
func passwordResetSentAtKey(email string) string { func passwordResetSentAtKey(email string) string {
return passwordResetSentAtKeyPrefix + email return passwordResetSentAtKeyPrefix + strings.ToLower(email)
} }
type emailCache struct { type emailCache struct {
@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
key := passwordResetSentAtKey(email) key := passwordResetSentAtKey(email)
return c.rdb.Set(ctx, key, "1", ttl).Err() return c.rdb.Set(ctx, key, "1", ttl).Err()
} }
// Notify email verification code methods
func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
key := notifyVerifyKey(email)
val, err := c.rdb.Get(ctx, key).Result()
if err != nil {
return nil, err
}
var data service.VerificationCodeData
if err := json.Unmarshal([]byte(val), &data); err != nil {
return nil, err
}
return &data, nil
}
func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
key := notifyVerifyKey(email)
val, err := json.Marshal(data)
if err != nil {
return err
}
return c.rdb.Set(ctx, key, val, ttl).Err()
}
func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
key := notifyVerifyKey(email)
return c.rdb.Del(ctx, key).Err()
}
// User-level rate limiting for notify email verification codes
func notifyCodeUserRateKey(userID int64) string {
return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
}
func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
// Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
return count, fmt.Errorf("expire notify code rate key: %w", err)
}
return count, nil
}
func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Get(ctx, key).Int64()
if err != nil {
return 0, err
}
return count, nil
}

View File

@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
} }
if cmd.BalanceCost > 0 { if cmd.BalanceCost > 0 {
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
if err != nil {
return err return err
} }
result.NewBalance = &newBalance
} }
if cmd.APIKeyQuotaCost > 0 { if cmd.APIKeyQuotaCost > 0 {
@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
} }
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) { if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
if err != nil {
return err return err
} }
result.QuotaState = quotaState
} }
return nil return nil
@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
return service.ErrSubscriptionNotFound return service.ErrSubscriptionNotFound
} }
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
res, err := tx.ExecContext(ctx, ` var newBalance float64
err := tx.QueryRowContext(ctx, `
UPDATE users UPDATE users
SET balance = balance - $1, SET balance = balance - $1,
updated_at = NOW() updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL WHERE id = $2 AND deleted_at IS NULL
`, amount, userID) RETURNING balance
`, amount, userID).Scan(&newBalance)
if errors.Is(err, sql.ErrNoRows) {
return 0, service.ErrUserNotFound
}
if err != nil { if err != nil {
return err return 0, err
} }
affected, err := res.RowsAffected() return newBalance, nil
if err != nil {
return err
}
if affected > 0 {
return nil
}
return service.ErrUserNotFound
} }
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
return nil return nil
} }
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
rows, err := tx.QueryContext(ctx, rows, err := tx.QueryContext(ctx,
`UPDATE accounts SET extra = ( `UPDATE accounts SET extra = (
COALESCE(extra, '{}'::jsonb) COALESCE(extra, '{}'::jsonb)
@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
jsonb_build_object( jsonb_build_object(
'quota_daily_used', 'quota_daily_used',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+dailyExpiredExpr+`
+ '24 hours'::interval <= NOW()
THEN $1 THEN $1
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
'quota_daily_start', 'quota_daily_start',
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+dailyExpiredExpr+`
+ '24 hours'::interval <= NOW()
THEN `+nowUTC+` THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
) )
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END ELSE '{}'::jsonb END
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
jsonb_build_object( jsonb_build_object(
'quota_weekly_used', 'quota_weekly_used',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+weeklyExpiredExpr+`
+ '168 hours'::interval <= NOW()
THEN $1 THEN $1
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
'quota_weekly_start', 'quota_weekly_start',
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) CASE WHEN `+weeklyExpiredExpr+`
+ '168 hours'::interval <= NOW()
THEN `+nowUTC+` THEN `+nowUTC+`
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
) )
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
ELSE '{}'::jsonb END
ELSE '{}'::jsonb END ELSE '{}'::jsonb END
), updated_at = NOW() ), updated_at = NOW()
WHERE id = $2 AND deleted_at IS NULL WHERE id = $2 AND deleted_at IS NULL
RETURNING RETURNING
COALESCE((extra->>'quota_used')::numeric, 0), COALESCE((extra->>'quota_used')::numeric, 0),
COALESCE((extra->>'quota_limit')::numeric, 0)`, COALESCE((extra->>'quota_limit')::numeric, 0),
COALESCE((extra->>'quota_daily_used')::numeric, 0),
COALESCE((extra->>'quota_daily_limit')::numeric, 0),
COALESCE((extra->>'quota_weekly_used')::numeric, 0),
COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
amount, accountID) amount, accountID)
if err != nil { if err != nil {
return err return nil, err
} }
defer func() { _ = rows.Close() }() defer func() { _ = rows.Close() }()
var newUsed, limit float64 var state service.AccountQuotaState
if rows.Next() { if rows.Next() {
if err := rows.Scan(&newUsed, &limit); err != nil { if err := rows.Scan(
return err &state.TotalUsed, &state.TotalLimit,
&state.DailyUsed, &state.DailyLimit,
&state.WeeklyUsed, &state.WeeklyLimit,
); err != nil {
return nil, err
} }
} else { } else {
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return err return nil, err
} }
return service.ErrAccountNotFound return nil, service.ErrAccountNotFound
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return err return nil, err
} }
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
return err return nil, err
} }
} }
return nil return &state, nil
} }

View File

@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
) )
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as: // usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args // 1. prepareUsageLogInsert().args
@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
"text", // model_mapping_chain "text", // model_mapping_chain
"text", // billing_tier "text", // billing_tier
"text", // billing_mode "text", // billing_mode
"numeric", // account_stats_cost
"timestamptz", // created_at "timestamptz", // created_at
} }
@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $16, $17, $14, $15, $16, $17,
$18, $19, $20, $21, $22, $23, $18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at RETURNING id, created_at
@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) )
SELECT SELECT
@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) AS (VALUES `) ) AS (VALUES `)
args := make([]any, 0, len(preparedList)*45) args := make([]any, 0, len(preparedList)*46)
argPos := 1 argPos := 1
for idx, prepared := range preparedList { for idx, prepared := range preparedList {
if idx > 0 { if idx > 0 {
@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) )
SELECT SELECT
@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
FROM input FROM input
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
model_mapping_chain, model_mapping_chain,
billing_tier, billing_tier,
billing_mode, billing_mode,
account_stats_cost,
created_at created_at
) VALUES ( ) VALUES (
$1, $2, $3, $4, $5, $6, $7, $1, $2, $3, $4, $5, $6, $7,
@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13, $10, $11, $12, $13,
$14, $15, $16, $17, $14, $15, $16, $17,
$18, $19, $20, $21, $22, $23, $18, $19, $20, $21, $22, $23,
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
) )
ON CONFLICT (request_id, api_key_id) DO NOTHING ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...) `, prepared.args...)
@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
modelMappingChain, modelMappingChain,
billingTier, billingTier,
billingMode, billingMode,
log.AccountStatsCost, // account_stats_cost
createdAt, createdAt,
}, },
} }
@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT SELECT
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
account_id, account_id,
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier // 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
} }
modelExpr := resolveModelDimensionExpression(source) modelExpr := resolveModelDimensionExpression(source)
@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost, COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs FROM usage_logs
%s %s
@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
} }
query := fmt.Sprintf(` query := fmt.Sprintf(`
@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
if accountID > 0 && userID == 0 && apiKeyID == 0 { if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
} }
query := fmt.Sprintf(` query := fmt.Sprintf(`
@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests, COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost, COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost, COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
modelMappingChain sql.NullString modelMappingChain sql.NullString
billingTier sql.NullString billingTier sql.NullString
billingMode sql.NullString billingMode sql.NullString
accountStatsCost sql.NullFloat64
createdAt time.Time createdAt time.Time
) )
@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&modelMappingChain, &modelMappingChain,
&billingTier, &billingTier,
&billingMode, &billingMode,
&accountStatsCost,
&createdAt, &createdAt,
); err != nil { ); err != nil {
return nil, err return nil, err
@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if billingMode.Valid { if billingMode.Valid {
log.BillingMode = &billingMode.String log.BillingMode = &billingMode.String
} }
if accountStatsCost.Valid {
log.AccountStatsCost = &accountStatsCost.Float64
}
return log, nil return log, nil
} }

View File

@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // model_mapping_chain
sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_tier
sqlmock.AnyArg(), // billing_mode sqlmock.AnyArg(), // billing_mode
sqlmock.AnyArg(), // account_stats_cost
createdAt, createdAt,
). ).
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)
@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{}, sql.NullString{},
sql.NullString{}, sql.NullString{},
false, false,
sql.NullInt64{}, // channel_id sql.NullInt64{}, // channel_id
sql.NullString{}, // model_mapping_chain sql.NullString{}, // model_mapping_chain
sql.NullString{}, // billing_tier sql.NullString{}, // billing_tier
sql.NullString{}, // billing_mode sql.NullString{}, // billing_mode
sql.NullFloat64{}, // account_stats_cost
now, now,
}}) }})
require.NoError(t, err) require.NoError(t, err)

View File

@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
query := ` query := `
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
FROM user_group_rate_multipliers ugr FROM user_group_rate_multipliers ugr
JOIN users u ON u.id = ugr.user_id JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
WHERE ugr.group_id = $1 WHERE ugr.group_id = $1
ORDER BY ugr.user_id ORDER BY ugr.user_id
` `

View File

@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
txClient = r.client txClient = r.client
} }
updated, err := txClient.User.UpdateOneID(userIn.ID). updateOp := txClient.User.UpdateOneID(userIn.ID).
SetEmail(userIn.Email). SetEmail(userIn.Email).
SetUsername(userIn.Username). SetUsername(userIn.Username).
SetNotes(userIn.Notes). SetNotes(userIn.Notes).
@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance). SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency). SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status). SetStatus(userIn.Status).
Save(ctx) SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
SetTotalRecharged(userIn.TotalRecharged)
if userIn.BalanceNotifyThreshold == nil {
updateOp = updateOp.ClearBalanceNotifyThreshold()
}
updated, err := updateOp.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
} }
@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
// Track cumulative recharge amount for percentage-based notifications
if amount > 0 {
update = update.AddTotalRecharged(amount)
}
n, err := update.Save(ctx)
if err != nil { if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, nil) return translatePersistenceError(err, service.ErrUserNotFound, nil)
} }
@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
dst.UpdatedAt = src.UpdatedAt dst.UpdatedAt = src.UpdatedAt
} }
// marshalExtraEmails serializes notify email entries to JSON for storage.
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
return service.MarshalNotifyEmails(entries)
}
// UpdateTotpSecret 更新用户的 TOTP 加密密钥 // UpdateTotpSecret 更新用户的 TOTP 加密密钥
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
client := clientFromContext(ctx, r.client) client := clientFromContext(ctx, r.client)

View File

@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) {
"allowed_groups": null, "allowed_groups": null,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z",
"balance_notify_enabled": false,
"balance_notify_threshold_type": "",
"balance_notify_threshold": null,
"balance_notify_extra_emails": null,
"total_recharged": 0,
"run_mode": "standard" "run_mode": "standard"
} }
}`, }`,
@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) {
"image_price_1k": null, "image_price_1k": null,
"image_price_2k": null, "image_price_2k": null,
"image_price_4k": null, "image_price_4k": null,
"claude_code_only": false, "claude_code_only": false,
"allow_messages_dispatch": false, "allow_messages_dispatch": false,
"fallback_group_id": null, "fallback_group_id": null,
"fallback_group_id_on_invalid_request": null, "fallback_group_id_on_invalid_request": null,
"allow_messages_dispatch": false,
"require_oauth_only": false, "require_oauth_only": false,
"require_privacy_set": false, "require_privacy_set": false,
"created_at": "2025-01-02T03:04:05Z", "created_at": "2025-01-02T03:04:05Z",
@ -587,26 +591,32 @@ func TestAPIContracts(t *testing.T) {
"enable_cch_signing": false, "enable_cch_signing": false,
"enable_fingerprint_unification": true, "enable_fingerprint_unification": true,
"enable_metadata_passthrough": false, "enable_metadata_passthrough": false,
"web_search_emulation_enabled": false,
"custom_menu_items": [],
"custom_endpoints": [],
"payment_enabled": false, "payment_enabled": false,
"payment_min_amount": 0, "payment_min_amount": 0,
"payment_max_amount": 0, "payment_max_amount": 0,
"payment_daily_limit": 0, "payment_daily_limit": 0,
"payment_order_timeout_minutes": 0, "payment_order_timeout_minutes": 0,
"payment_max_pending_orders": 0, "payment_max_pending_orders": 0,
"payment_enabled_types": null,
"payment_balance_disabled": false, "payment_balance_disabled": false,
"payment_load_balance_strategy": "", "payment_load_balance_strategy": "",
"payment_product_name_prefix": "", "payment_product_name_prefix": "",
"payment_product_name_suffix": "", "payment_product_name_suffix": "",
"payment_help_image_url": "", "payment_help_image_url": "",
"payment_help_text": "", "payment_help_text": "",
"payment_enabled_types": null,
"payment_cancel_rate_limit_enabled": false, "payment_cancel_rate_limit_enabled": false,
"payment_cancel_rate_limit_max": 0, "payment_cancel_rate_limit_max": 0,
"payment_cancel_rate_limit_window": 0, "payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "", "payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "", "payment_cancel_rate_limit_window_mode": "",
"custom_menu_items": [], "balance_low_notify_enabled": false,
"custom_endpoints": [] "account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": []
} }
}`, }`,
}, },
@ -699,7 +709,7 @@ func newContractDeps(t *testing.T) *contractDeps {
RunMode: config.RunModeStandard, RunMode: config.RunModeStandard,
} }
userService := service.NewUserService(userRepo, nil, nil) userService := service.NewUserService(userRepo, nil, nil, nil)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo() usageRepo := newStubUsageLogRepo()

View File

@ -2,12 +2,15 @@
package server package server
import ( import (
"context"
"log" "log"
"log/slog"
"net/http" "net/http"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/handler"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@ -56,6 +59,42 @@ func ProvideRouter(
} }
} }
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
service.SetWebSearchManager(nil)
return
}
configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
for _, p := range cfg.Providers {
if p.APIKey == "" {
continue
}
pc := websearch.ProviderConfig{
Type: p.Type,
APIKey: p.APIKey,
QuotaLimit: derefInt64(p.QuotaLimit),
ExpiresAt: p.ExpiresAt,
}
if p.SubscribedAt != nil {
pc.SubscribedAt = p.SubscribedAt
}
if p.ProxyID != nil {
pc.ProxyID = *p.ProxyID
if u, ok := proxyURLs[*p.ProxyID]; ok {
pc.ProxyURL = u
} else {
// Proxy configured but not found — skip this provider to prevent direct connection.
slog.Warn("websearch: proxy not found for provider, skipping",
"provider", p.Type, "proxy_id", *p.ProxyID)
continue
}
}
configs = append(configs, pc)
}
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
})
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
} }
@ -102,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取 // 不设置 ReadTimeout因为大请求体可能需要较长时间读取
} }
} }
func derefInt64(p *int64) int64 {
if p == nil {
return 0
}
return *p
}

View File

@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
return &clone, nil return &clone, nil
}, },
} }
userService := service.NewUserService(userRepo, nil, nil) userService := service.NewUserService(userRepo, nil, nil, nil)
router := gin.New() router := gin.New()
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))

View File

@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
userRepo := &stubJWTUserRepo{users: users} userRepo := &stubJWTUserRepo{users: users}
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc) mw := NewJWTAuthMiddleware(authSvc, userSvc)
r := gin.New() r := gin.New()

View File

@ -18,6 +18,8 @@ const (
NonceTemplate = "__CSP_NONCE__" NonceTemplate = "__CSP_NONCE__"
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
CloudflareInsightsDomain = "https://static.cloudflareinsights.com" CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
// StripeDomain is the domain for Stripe.js SDK
StripeDomain = "https://*.stripe.com"
) )
// GenerateNonce generates a cryptographically secure random nonce. // GenerateNonce generates a cryptographically secure random nonce.
@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool {
strings.HasPrefix(path, "/responses") strings.HasPrefix(path, "/responses")
} }
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
// This allows the application to work correctly even if the config file has an older CSP policy. // and Stripe.js domains. This allows the application to work correctly even if the
// config file has an older CSP policy.
func enhanceCSPPolicy(policy string) string { func enhanceCSPPolicy(policy string) string {
// Add nonce placeholder to script-src if not present // Add nonce placeholder to script-src if not present
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") { if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string {
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain) policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
} }
// Add Stripe.js domain to script-src and frame-src if not present
if !strings.Contains(policy, "stripe.com") {
policy = addToDirective(policy, "script-src", StripeDomain)
policy = addToDirective(policy, "frame-src", StripeDomain)
}
return policy return policy
} }

View File

@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Beta 策略配置 // Beta 策略配置
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
// Web Search 模拟配置
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage)
} }
} }

View File

@ -39,6 +39,7 @@ func RegisterPaymentRoutes(
orders.GET("/:id", paymentHandler.GetOrder) orders.GET("/:id", paymentHandler.GetOrder)
orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder)
orders.POST("/:id/refund-request", paymentHandler.RequestRefund) orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
} }
} }

View File

@ -26,6 +26,15 @@ func RegisterUserRoutes(
user.PUT("/password", h.User.ChangePassword) user.PUT("/password", h.User.ChangePassword)
user.PUT("", h.User.UpdateProfile) user.PUT("", h.User.UpdateProfile)
// 通知邮箱管理
notifyEmail := user.Group("/notify-email")
{
notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
}
// TOTP 双因素认证 // TOTP 双因素认证
totp := user.Group("/totp") totp := user.Group("/totp")
{ {

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"hash/fnv" "hash/fnv"
"log/slog"
"reflect" "reflect"
"sort" "sort"
"strconv" "strconv"
@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool {
return false return false
} }
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)” // IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"
// //
// 新字段accounts.extra.openai_passthrough。 // 新字段accounts.extra.openai_passthrough。
// 兼容字段accounts.extra.openai_oauth_passthrough历史 OAuth 开关)。 // 兼容字段accounts.extra.openai_oauth_passthrough历史 OAuth 开关)。
@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return resolvedDefault return resolvedDefault
} }
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 // IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
// 字段accounts.extra.openai_ws_force_http。 // 字段accounts.extra.openai_ws_force_http。
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
if a == nil || !a.IsOpenAI() || a.Extra == nil { if a == nil || !a.IsOpenAI() || a.Extra == nil {
@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
} }
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)” // IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"
// 字段accounts.extra.anthropic_passthrough。 // 字段accounts.extra.anthropic_passthrough。
// 字段缺失或类型不正确时,按 false关闭处理。 // 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
return ok && enabled return ok && enabled
} }
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。 // WebSearch 模拟三态常量
const (
WebSearchModeDefault = "default" // 跟随渠道配置
WebSearchModeEnabled = "enabled" // 强制开启
WebSearchModeDisabled = "disabled" // 强制关闭
)
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
// 三态default跟随渠道/ enabled强制开启/ disabled强制关闭
// 兼容旧 bool 值true→enabled, false→default并记录 debug 日志)。
func (a *Account) GetWebSearchEmulationMode() string {
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
return WebSearchModeDefault
}
raw := a.Extra[featureKeyWebSearchEmulation]
// Tolerant: legacy bool values (pre-migration or stale writes)
if b, ok := raw.(bool); ok {
slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b)
if b {
return WebSearchModeEnabled
}
return WebSearchModeDefault
}
mode, ok := raw.(string)
if !ok {
return WebSearchModeDefault
}
switch mode {
case WebSearchModeEnabled, WebSearchModeDisabled:
return mode
default:
return WebSearchModeDefault
}
}
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
// 字段accounts.extra.codex_cli_only。 // 字段accounts.extra.codex_cli_only。
// 字段缺失或类型不正确时,按 false关闭处理。 // 字段缺失或类型不正确时,按 false关闭处理。
func (a *Account) IsCodexCLIOnlyEnabled() bool { func (a *Account) IsCodexCLIOnlyEnabled() bool {
@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time {
return time.Time{} return time.Time{}
} }
// getExtraBool 从 Extra 中读取指定 key 的 bool 值
func (a *Account) getExtraBool(key string) bool {
if a.Extra == nil {
return false
}
if v, ok := a.Extra[key]; ok {
if b, ok := v.(bool); ok {
return b
}
}
return false
}
// getExtraString 从 Extra 中读取指定 key 的字符串值 // getExtraString 从 Extra 中读取指定 key 的字符串值
func (a *Account) getExtraString(key string) string { func (a *Account) getExtraString(key string) string {
if a.Extra == nil { if a.Extra == nil {
@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string {
return "" return ""
} }
// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
func (a *Account) getExtraStringDefault(key, defaultVal string) string {
if v := a.getExtraString(key); v != "" {
return v
}
return defaultVal
}
// getExtraInt 从 Extra 中读取指定 key 的 int 值 // getExtraInt 从 Extra 中读取指定 key 的 int 值
func (a *Account) getExtraInt(key string) int { func (a *Account) getExtraInt(key string) int {
if a.Extra == nil { if a.Extra == nil {
@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string {
return "UTC" return "UTC"
} }
// --- Quota Notification Getters ---
// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) {
enabled = a.getExtraBool("quota_notify_" + dim + "_enabled")
threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold")
thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed)
return
}
func (a *Account) GetQuotaNotifyDailyEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimDaily)
return e
}
func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimDaily)
return t
}
func (a *Account) GetQuotaNotifyDailyThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimDaily)
return tt
}
func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly)
return e
}
func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimWeekly)
return t
}
func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimWeekly)
return tt
}
func (a *Account) GetQuotaNotifyTotalEnabled() bool {
e, _, _ := a.QuotaNotifyConfig(quotaDimTotal)
return e
}
func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
_, t, _ := a.QuotaNotifyConfig(quotaDimTotal)
return t
}
func (a *Account) GetQuotaNotifyTotalThresholdType() string {
_, _, tt := a.QuotaNotifyConfig(quotaDimTotal)
return tt
}
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 // nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
t := after.In(tz) t := after.In(tz)

View File

@ -0,0 +1,236 @@
package service
import (
"context"
"strings"
)
// resolveAccountStatsCost 计算账号统计定价费用。
// 返回 nil 表示不覆盖使用默认公式total_cost × account_rate_multiplier
//
// 优先级(先命中为准):
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost
// 3. 模型定价文件LiteLLM中上游模型的默认价格
// 4. nil → 走默认公式total_cost × account_rate_multiplier
//
// upstreamModel 是最终发往上游的模型 ID。
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
func resolveAccountStatsCost(
ctx context.Context,
channelService *ChannelService,
billingService *BillingService,
accountID int64,
groupID int64,
upstreamModel string,
tokens UsageTokens,
requestCount int,
totalCost float64,
) *float64 {
if channelService == nil || upstreamModel == "" {
return nil
}
channel, err := channelService.GetChannelForGroup(ctx, groupID)
if err != nil || channel == nil {
return nil
}
platform := channelService.GetGroupPlatform(ctx, groupID)
// 优先级 1自定义规则始终尝试
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
return cost
}
// 优先级 2渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
if channel.ApplyPricingToAccountStats {
cost := totalCost
if cost <= 0 {
return nil
}
return &cost
}
// 优先级 3模型定价文件LiteLLM默认价格
if billingService != nil {
return tryModelFilePricing(billingService, upstreamModel, tokens)
}
return nil
}
// tryModelFilePricing 使用模型定价文件LiteLLM/fallback中的标准价格计算费用。
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
pricing, err := billingService.GetModelPricing(model)
if err != nil || pricing == nil {
return nil
}
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
if cost <= 0 {
return nil
}
return &cost
}
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
func tryCustomRules(
channel *Channel, accountID, groupID int64,
platform, model string, tokens UsageTokens, requestCount int,
) *float64 {
modelLower := strings.ToLower(model)
for _, rule := range channel.AccountStatsPricingRules {
if !matchAccountStatsRule(&rule, accountID, groupID) {
continue
}
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
if pricing == nil {
continue // 规则匹配但模型不在规则定价中,继续下一条
}
return calculateStatsCost(pricing, tokens, requestCount)
}
return nil
}
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
// 匹配条件accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
return false
}
for _, id := range rule.AccountIDs {
if id == accountID {
return true
}
}
for _, id := range rule.GroupIDs {
if id == groupID {
return true
}
}
return false
}
// findPricingForModel 在定价列表中查找匹配的模型定价。
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
// 精确匹配优先
for i := range pricingList {
p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) {
continue
}
for _, m := range p.Models {
if strings.ToLower(m) == modelLower {
return p
}
}
}
// 通配符匹配:按配置顺序,先匹配先使用
for i := range pricingList {
p := &pricingList[i]
if !isPlatformMatch(platform, p.Platform) {
continue
}
for _, m := range p.Models {
ml := strings.ToLower(m)
if !strings.HasSuffix(ml, "*") {
continue
}
prefix := strings.TrimSuffix(ml, "*")
if strings.HasPrefix(modelLower, prefix) {
return p
}
}
}
return nil
}
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
if queryPlatform == "" || pricingPlatform == "" {
return true
}
return queryPlatform == pricingPlatform
}
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
if pricing == nil {
return nil
}
switch pricing.BillingMode {
case BillingModePerRequest, BillingModeImage:
return calculatePerRequestStatsCost(pricing, requestCount)
default:
return calculateTokenStatsCost(pricing, tokens)
}
}
// calculatePerRequestStatsCost 按次/图片计费。
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
return nil
}
cost := *pricing.PerRequestPrice * float64(requestCount)
return &cost
}
// calculateTokenStatsCost Token 计费。
// If the pricing has intervals, find the matching interval by total token count
// and use its prices instead of the flat pricing fields.
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
p := pricing
if len(pricing.Intervals) > 0 {
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
p = &ChannelModelPricing{
InputPrice: iv.InputPrice,
OutputPrice: iv.OutputPrice,
CacheWritePrice: iv.CacheWritePrice,
CacheReadPrice: iv.CacheReadPrice,
PerRequestPrice: iv.PerRequestPrice,
}
}
}
deref := func(ptr *float64) float64 {
if ptr == nil {
return 0
}
return *ptr
}
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
if cost <= 0 {
return nil
}
return &cost
}
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
// It resolves the upstream model (falling back to the requested model) and calls
// the 4-level priority chain via resolveAccountStatsCost.
func applyAccountStatsCost(
ctx context.Context,
usageLog *UsageLog,
cs *ChannelService, bs *BillingService,
accountID int64, groupID int64,
upstreamModel, requestedModel string,
tokens UsageTokens,
totalCost float64,
) {
model := upstreamModel
if model == "" {
model = requestedModel
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
)
}

View File

@ -0,0 +1,771 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// matchAccountStatsRule
// ---------------------------------------------------------------------------
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
rule := &AccountStatsPricingRule{}
require.False(t, matchAccountStatsRule(rule, 1, 10))
}
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
require.True(t, matchAccountStatsRule(rule, 2, 999))
}
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
require.True(t, matchAccountStatsRule(rule, 999, 20))
}
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.True(t, matchAccountStatsRule(rule, 2, 999))
}
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.True(t, matchAccountStatsRule(rule, 999, 10))
}
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
rule := &AccountStatsPricingRule{
AccountIDs: []int64{1, 2},
GroupIDs: []int64{10, 20},
}
require.False(t, matchAccountStatsRule(rule, 999, 999))
}
// ---------------------------------------------------------------------------
// findPricingForModel
// ---------------------------------------------------------------------------
func TestFindPricingForModel(t *testing.T) {
exactPricing := ChannelModelPricing{
ID: 1,
Models: []string{"claude-opus-4"},
}
wildcardPricing := ChannelModelPricing{
ID: 2,
Models: []string{"claude-*"},
}
platformPricing := ChannelModelPricing{
ID: 3,
Platform: "openai",
Models: []string{"gpt-4o"},
}
emptyPlatformPricing := ChannelModelPricing{
ID: 4,
Models: []string{"gemini-2.5-pro"},
}
tests := []struct {
name string
list []ChannelModelPricing
platform string
model string
wantID int64
wantNil bool
}{
{
name: "exact match",
list: []ChannelModelPricing{exactPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 1,
},
{
name: "exact match case insensitive",
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
platform: "",
model: "claude-opus-4",
wantID: 5,
},
{
name: "wildcard match",
list: []ChannelModelPricing{wildcardPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 2,
},
{
name: "exact match takes priority over wildcard",
list: []ChannelModelPricing{wildcardPricing, exactPricing},
platform: "anthropic",
model: "claude-opus-4",
wantID: 1,
},
{
name: "platform mismatch skipped",
list: []ChannelModelPricing{platformPricing},
platform: "anthropic",
model: "gpt-4o",
wantNil: true,
},
{
name: "empty platform in pricing matches any",
list: []ChannelModelPricing{emptyPlatformPricing},
platform: "gemini",
model: "gemini-2.5-pro",
wantID: 4,
},
{
name: "empty platform in query matches any pricing platform",
list: []ChannelModelPricing{platformPricing},
platform: "",
model: "gpt-4o",
wantID: 3,
},
{
name: "no match at all",
list: []ChannelModelPricing{exactPricing, wildcardPricing},
platform: "anthropic",
model: "gpt-4o",
wantNil: true,
},
{
name: "empty list returns nil",
list: nil,
model: "claude-opus-4",
wantNil: true,
},
{
name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-opus-4",
wantID: 10, // config order: "claude-*" is first and matches, so it wins
},
{
name: "shorter wildcard used when longer does not match",
list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}},
},
platform: "",
model: "claude-sonnet-4",
wantID: 10, // only "claude-*" matches
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := findPricingForModel(tt.list, tt.platform, tt.model)
if tt.wantNil {
require.Nil(t, result)
return
}
require.NotNil(t, result)
require.Equal(t, tt.wantID, result.ID)
})
}
}
// ---------------------------------------------------------------------------
// calculateStatsCost
// ---------------------------------------------------------------------------
func TestCalculateStatsCost_NilPricing(t *testing.T) {
result := calculateStatsCost(nil, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
CacheWritePrice: testPtrFloat64(0.003),
CacheReadPrice: testPtrFloat64(0.0005),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
CacheReadTokens: 300,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require.InDelta(t, 0.95, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
ImageOutputPrice: testPtrFloat64(0.01),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
ImageOutputTokens: 10,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require.InDelta(t, 0.3, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
// Only input contributes: 100*0.001 = 0.1
require.InDelta(t, 0.1, *result, 1e-12)
}
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{} // all zeros
result := calculateStatsCost(pricing, tokens, 1)
// totalCost == 0 → returns nil (does not override, falls back to default formula)
require.Nil(t, result)
}
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
}
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
result := calculateStatsCost(pricing, tokens, 3)
require.NotNil(t, result)
// 0.05 * 3 = 0.15
require.InDelta(t, 0.15, *result, 1e-12)
}
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
// PerRequestPrice is nil
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0),
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
require.Nil(t, result)
}
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.10),
}
result := calculateStatsCost(pricing, UsageTokens{}, 2)
require.NotNil(t, result)
// 0.10 * 2 = 0.20
require.InDelta(t, 0.20, *result, 1e-12)
}
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
pricing := &ChannelModelPricing{
BillingMode: BillingModeImage,
// PerRequestPrice is nil
}
result := calculateStatsCost(pricing, UsageTokens{}, 1)
require.Nil(t, result)
}
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
// BillingMode is empty string (default) → falls into token billing
pricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(0.001),
OutputPrice: testPtrFloat64(0.002),
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
}
result := calculateStatsCost(pricing, tokens, 1)
require.NotNil(t, result)
require.InDelta(t, 0.2, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// tryCustomRules — 多规则顺序测试
// ---------------------------------------------------------------------------
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
},
},
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
// 应使用第一条规则的价格100*0.01 + 50*0.02 = 2.0
require.InDelta(t, 2.0, *result, 1e-12)
}
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
AccountIDs: []int64{888}, // 不匹配
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
},
},
{
GroupIDs: []int64{1}, // 匹配
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
// 跳过规则1账号不匹配使用规则2100*0.05 = 5.0
require.InDelta(t, 5.0, *result, 1e-12)
}
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
AccountIDs: []int64{888},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
require.Nil(t, result) // 账号和分组都不匹配
}
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
channel := &Channel{
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
},
},
{
GroupIDs: []int64{1},
Pricing: []ChannelModelPricing{
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
},
},
},
}
tokens := UsageTokens{InputTokens: 100}
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
require.NotNil(t, result)
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
}
// ---------------------------------------------------------------------------
// tryModelFilePricing
// ---------------------------------------------------------------------------
// newTestBillingServiceWithPrices creates a BillingService with pre-populated
// fallback prices for testing. No config or pricing service is needed.
// The key must match what getFallbackPricing resolves to for a given model name.
// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
return &BillingService{
fallbackPrices: prices,
}
}
func TestTryModelFilePricing_Success(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
// "nonexistent-model" does not match any fallback pattern
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := tryModelFilePricing(bs, "nonexistent-model", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_NilFallback(t *testing.T) {
// getFallbackPricing returns nil when key maps to nil
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": nil,
})
tokens := UsageTokens{InputTokens: 100}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_ZeroCost(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.Nil(t, result)
}
func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
ImageOutputPricePerToken: 0.01,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
ImageOutputTokens: 10,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
require.InDelta(t, 0.3, *result, 1e-12)
}
func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
CacheCreationPricePerToken: 0.003,
CacheReadPricePerToken: 0.0005,
},
})
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 50,
CacheCreationTokens: 200,
CacheReadTokens: 300,
}
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
require.InDelta(t, 0.95, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// resolveAccountStatsCost — integration tests covering the 4-level priority chain
// ---------------------------------------------------------------------------
func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
result := resolveAccountStatsCost(
context.Background(),
nil, // channelService is nil
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 1, "claude-sonnet-4",
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
cs := newTestChannelServiceForStats(t, &Channel{
ID: 1,
Status: StatusActive,
}, 1, "")
result := resolveAccountStatsCost(
context.Background(),
cs,
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 1, "", // empty upstream model
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
// Group 99 is NOT in the cache, so GetChannelForGroup returns nil
cs := newTestChannelServiceForStats(t, &Channel{
ID: 1,
Status: StatusActive,
}, 1, "")
result := resolveAccountStatsCost(
context.Background(),
cs,
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
1, 99, "claude-sonnet-4", // groupID 99 has no channel
UsageTokens{InputTokens: 100}, 1, 0.5,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{10},
Pricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"claude-sonnet-4"},
InputPrice: testPtrFloat64(0.01),
OutputPrice: testPtrFloat64(0.02),
},
},
},
},
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, nil, // billingService not needed when custom rule hits
1, 10, "claude-sonnet-4",
tokens, 1, 999.0, // totalCost ignored because custom rule hits
)
require.NotNil(t, result)
// 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
require.InDelta(t, 2.0, *result, 1e-12)
}
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
tokens, 1, 0.75, // totalCost = 0.75
)
require.NotNil(t, result)
require.InDelta(t, 0.75, *result, 1e-12)
}
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
UsageTokens{}, 1, 0.0, // totalCost = 0
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false, // not enabled
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 0.001,
OutputPricePerToken: 0.002,
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, bs,
1, 10, "claude-sonnet-4",
tokens, 1, 999.0, // totalCost ignored
)
require.NotNil(t, result)
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
require.InDelta(t, 0.2, *result, 1e-12)
}
func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false,
// No custom rules
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
// BillingService with no pricing for the model
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
result := resolveAccountStatsCost(
context.Background(),
cs, bs,
1, 10, "totally-unknown-model",
tokens, 1, 0.0,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: false,
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
result := resolveAccountStatsCost(
context.Background(),
cs, nil, // billingService is nil
1, 10, "claude-sonnet-4",
UsageTokens{InputTokens: 100}, 1, 0.0,
)
require.Nil(t, result)
}
func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
// Both custom rule and ApplyPricingToAccountStats are configured;
// custom rule should take precedence.
channel := &Channel{
ID: 1,
Status: StatusActive,
ApplyPricingToAccountStats: true,
AccountStatsPricingRules: []AccountStatsPricingRule{
{
GroupIDs: []int64{10},
Pricing: []ChannelModelPricing{
{
ID: 100,
Models: []string{"claude-sonnet-4"},
InputPrice: testPtrFloat64(0.05),
},
},
},
},
}
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
tokens := UsageTokens{InputTokens: 100}
result := resolveAccountStatsCost(
context.Background(),
cs, nil,
1, 10, "claude-sonnet-4",
tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
)
require.NotNil(t, result)
// Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
require.InDelta(t, 5.0, *result, 1e-12)
}
// ---------------------------------------------------------------------------
// helpers for resolveAccountStatsCost tests
// ---------------------------------------------------------------------------
// newTestChannelServiceForStats creates a ChannelService with a single channel
// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
t.Helper()
cache := newEmptyChannelCache()
cache.channelByGroupID[groupID] = channel
cache.groupPlatform[groupID] = platform
cs := &ChannelService{}
cache.loadedAt = time.Now()
cs.cache.Store(cache)
return cs
}

View File

@ -0,0 +1,105 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
}
require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_Default(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: true},
}
// bool true → tolerant fallback → enabled (not default)
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: false},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: nil,
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
a := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}
func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
}
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
}

View File

@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected") panic("unexpected")
} }
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
panic("unexpected")
}
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct { type apiKeyRepoStubForGroupUpdate struct {
@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected") panic("unexpected")
} }
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected") panic("unexpected")
} }
@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected") panic("unexpected")
} }
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct { type groupRepoStubForGroupUpdate struct {

View File

@ -12,12 +12,12 @@ import (
type accountRepoStubForClearAccountError struct { type accountRepoStubForClearAccountError struct {
mockAccountRepoForGemini mockAccountRepoForGemini
account *Account account *Account
clearErrorCalls int clearErrorCalls int
clearRateLimitCalls int clearRateLimitCalls int
clearAntigravityCalls int clearAntigravityCalls int
clearModelRateLimitCalls int clearModelRateLimitCalls int
clearTempUnschedCalls int clearTempUnschedCalls int
} }
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) { func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
resetAt := time.Now().Add(5 * time.Minute) resetAt := time.Now().Add(5 * time.Minute)
repo := &accountRepoStubForClearAccountError{ repo := &accountRepoStubForClearAccountError{
account: &Account{ account: &Account{
ID: 31, ID: 31,
Platform: PlatformOpenAI, Platform: PlatformOpenAI,
Type: AccountTypeOAuth, Type: AccountTypeOAuth,
Status: StatusError, Status: StatusError,
ErrorMessage: "refresh failed", ErrorMessage: "refresh failed",
RateLimitResetAt: &resetAt, RateLimitResetAt: &resetAt,
TempUnschedulableUntil: &until, TempUnschedulableUntil: &until,
TempUnschedulableReason: "missing refresh token", TempUnschedulableReason: "missing refresh token",
}, },
} }

View File

@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
Role string `json:"role"` Role string `json:"role"`
Balance float64 `json:"balance"` Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"` Concurrency int `json:"concurrency"`
// Balance notification fields (required for CheckBalanceAfterDeduction)
Email string `json:"email"`
Username string `json:"username"`
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
} }
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照

View File

@ -6,6 +6,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"math/rand/v2" "math/rand/v2"
"time" "time"
@ -13,7 +14,7 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 3 const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
s.authCacheL1.Del(cacheKey) s.authCacheL1.Del(cacheKey)
}); err != nil { }); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation // Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error()) slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
} }
} }
@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
RateLimit1d: apiKey.RateLimit1d, RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d, RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{ User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID, ID: apiKey.User.ID,
Status: apiKey.User.Status, Status: apiKey.User.Status,
Role: apiKey.User.Role, Role: apiKey.User.Role,
Balance: apiKey.User.Balance, Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency, Concurrency: apiKey.User.Concurrency,
Email: apiKey.User.Email,
Username: apiKey.User.Username,
BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
}, },
} }
if apiKey.Group != nil { if apiKey.Group != nil {
@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
RateLimit1d: snapshot.RateLimit1d, RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d, RateLimit7d: snapshot.RateLimit7d,
User: &User{ User: &User{
ID: snapshot.User.ID, ID: snapshot.User.ID,
Status: snapshot.User.Status, Status: snapshot.User.Status,
Role: snapshot.User.Role, Role: snapshot.User.Role,
Balance: snapshot.User.Balance, Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency, Concurrency: snapshot.User.Concurrency,
Email: snapshot.User.Email,
Username: snapshot.User.Username,
BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
}, },
} }
if snapshot.Group != nil { if snapshot.Group != nil {

View File

@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
return nil return nil
} }
func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
return nil, nil
}
func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) { func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil return nil, nil
} }
@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
return nil return nil
} }
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
return 0, nil
}
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
return 0, nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{ cfg := &config.Config{
JWT: config.JWTConfig{ JWT: config.JWTConfig{

View File

@ -0,0 +1,404 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
// in-memory settings repo and a non-nil emailService so that the guard-clause
// nil-checks pass. The emailService is intentionally minimal — tests must
// avoid crossing scenarios that would actually dispatch emails.
func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) {
repo := newMockSettingRepo()
// EmailService is a concrete type; construct with the same repo so that
// any accidental fallback reads still succeed. Tests should not trigger a
// crossing that reaches SendEmail.
email := NewEmailService(repo, nil)
return NewBalanceNotifyService(email, repo, nil), repo
}
// ---------- guard clauses ----------
func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
// Should not panic.
s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50)
}
func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
u := &User{ID: 1, BalanceNotifyEnabled: false}
// Even with a crossing, disabled flag short-circuits.
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
u := &User{ID: 1, BalanceNotifyEnabled: true}
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "0"
u := &User{ID: 1, BalanceNotifyEnabled: true}
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default
customThreshold := 5.0
u := &User{
ID: 1,
BalanceNotifyEnabled: true,
BalanceNotifyThreshold: &customThreshold,
}
// User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
// cross 5, so nothing fires (verified by absence of panic).
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
}
func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
u := &User{ID: 1, BalanceNotifyEnabled: true}
// 100 -> 95, both remain above threshold=10, no crossing.
s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5)
// 5 -> 3, both already below threshold, no crossing (only fires on first
// cross from above-to-below).
s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2)
}
// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
// Should not panic.
s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil)
}
func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil)
}
func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil)
}
func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
a := &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_daily_limit": 1000.0,
"quota_daily_used": 950.0,
},
}
// Global disabled → no processing even if a dim would cross.
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil)
}
// ---------- sanity: internal helpers still work ----------
func TestGetBalanceNotifyConfig_AllFields(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5"
repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay"
enabled, threshold, url := s.getBalanceNotifyConfig(context.Background())
require.True(t, enabled)
require.Equal(t, 12.5, threshold)
require.Equal(t, "https://example.com/pay", url)
}
func TestGetBalanceNotifyConfig_Disabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
enabled, _, _ := s.getBalanceNotifyConfig(context.Background())
require.False(t, enabled)
}
func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number"
enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background())
require.True(t, enabled)
require.Equal(t, 0.0, threshold)
}
func TestIsAccountQuotaNotifyEnabled(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
// Missing key → false
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
// Explicit "false"
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
// Explicit "true"
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true"
require.True(t, s.isAccountQuotaNotifyEnabled(context.Background()))
}
func TestGetSiteName_FallsBackToDefault(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
name := s.getSiteName(context.Background())
require.Equal(t, defaultSiteName, name)
}
func TestGetSiteName_Configured(t *testing.T) {
s, repo := newBalanceNotifyServiceForTest()
repo.data[SettingKeySiteName] = "My Site"
require.Equal(t, "My Site", s.getSiteName(context.Background()))
}
// ---------- crossedDownward ----------
func TestCrossedDownward_CrossesBelow(t *testing.T) {
// oldBalance > threshold, newBalance < threshold → true
require.True(t, crossedDownward(100, 5, 10))
}
func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
// oldBalance > threshold, newBalance == threshold → false (not below)
require.False(t, crossedDownward(100, 10, 10))
}
func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
// oldBalance == threshold, newBalance < threshold → true
// (at-or-above → below counts as a crossing)
require.True(t, crossedDownward(10, 5, 10))
}
func TestCrossedDownward_AlreadyBelow(t *testing.T) {
// oldBalance < threshold → false (already below, no new crossing)
require.False(t, crossedDownward(5, 3, 10))
}
func TestCrossedDownward_BothAbove(t *testing.T) {
// oldBalance > threshold, newBalance > threshold → false (no crossing)
require.False(t, crossedDownward(100, 50, 10))
}
func TestCrossedDownward_ZeroThreshold(t *testing.T) {
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
// Typical case: positive balances should not fire when threshold is 0.
require.False(t, crossedDownward(10, 5, 0))
require.False(t, crossedDownward(0, 0, 0))
}
func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
// Edge case: newBalance goes negative with threshold=0.
require.True(t, crossedDownward(5, -1, 0))
}
func TestCrossedDownward_NegativeValues(t *testing.T) {
// Both already negative, threshold is positive → no crossing (already below).
require.False(t, crossedDownward(-5, -10, 10))
}
func TestCrossedDownward_LargeDecrement(t *testing.T) {
// A single large deduction crosses the threshold.
require.True(t, crossedDownward(1000, 0.5, 100))
}
func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
// A tiny deduction stays above threshold.
require.False(t, crossedDownward(100, 99.99, 10))
}
// ---------- checkQuotaDimCrossings ----------
func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// Empty dims → no crossing, no panic.
s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: false, // disabled
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
// Disabled dimension should be skipped even if crossing would occur.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 0, // zero threshold
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
// Zero threshold → skipped.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 300,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 800,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
// Negative resolved threshold → skipped.
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 1200,
thresholdType: thresholdTypeFixed,
currentUsed: 950,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
dims := []quotaDim{
{
name: quotaDimWeekly,
enabled: true,
threshold: 30,
thresholdType: thresholdTypePercentage,
currentUsed: 500,
limit: 1000,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// limit=0 → resolvedThreshold returns 0 → skipped.
dims := []quotaDim{
{
name: quotaDimTotal,
enabled: true,
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 50,
limit: 0,
},
}
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}
func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
s, _ := newBalanceNotifyServiceForTest()
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
// dim1: no crossing (both below effective threshold)
// dim2: disabled (skipped)
// dim3: zero threshold (skipped)
dims := []quotaDim{
{
name: quotaDimDaily,
enabled: true,
threshold: 400,
thresholdType: thresholdTypeFixed,
currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
limit: 1000,
},
{
name: quotaDimWeekly,
enabled: false,
threshold: 100,
thresholdType: thresholdTypeFixed,
currentUsed: 900,
limit: 1000,
},
{
name: quotaDimTotal,
enabled: true,
threshold: 0,
thresholdType: thresholdTypeFixed,
currentUsed: 500,
limit: 1000,
},
}
// None should trigger. No panic expected.
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
}

View File

@ -0,0 +1,147 @@
//go:build unit
package service
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// These tests guard against fmt.Sprintf arg-count mismatches in the email
// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
// the output, which these assertions will catch.
// ---------- buildBalanceLowEmailBody ----------
func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "")
// All substituted values should appear in the output.
require.Contains(t, body, "MySite")
require.Contains(t, body, "Alice")
require.Contains(t, body, "$3.14")
require.Contains(t, body, "$10.00")
// No fmt.Sprintf format error markers.
require.NotContains(t, body, "%!")
require.NotContains(t, body, "MISSING")
require.NotContains(t, body, "EXTRA")
}
func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay")
// The recharge anchor element should appear with the URL.
require.Contains(t, body, `href="https://example.com/pay"`)
require.Contains(t, body, "立即充值")
require.NotContains(t, body, "%!")
}
func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) {
s := &BalanceNotifyService{}
// Try a URL with characters that need HTML escaping.
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b=<script>`)
// `&` and `<` should be escaped in the href.
require.Contains(t, body, "&amp;")
require.Contains(t, body, "&lt;script&gt;")
require.NotContains(t, body, "<script>")
}
func TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
// The anchor element should not be rendered (style class may still appear).
require.NotContains(t, body, `<a href`)
require.NotContains(t, body, "立即充值")
}
// ---------- buildQuotaAlertEmailBody ----------
func TestBuildQuotaAlertEmailBody_AllFieldsPresent(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
42, // accountID
"acc-foo", // accountName
"anthropic", // platform
"日限额 / Daily", // dimLabel
750.50, // used
1000.0, // limit
249.50, // remaining
"$249.50", // thresholdDisplay
"MySite", // siteName
)
require.Contains(t, body, "MySite")
require.Contains(t, body, "#42")
require.Contains(t, body, "acc-foo")
require.Contains(t, body, "anthropic")
require.Contains(t, body, "Daily")
require.Contains(t, body, "$750.50")
require.Contains(t, body, "$1000.00")
require.Contains(t, body, "$249.50")
// No format error markers.
require.NotContains(t, body, "%!")
require.NotContains(t, body, "MISSING")
require.NotContains(t, body, "EXTRA")
}
func TestBuildQuotaAlertEmailBody_UnlimitedDisplay(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
100.0, 0.0, // limit=0 triggers unlimited branch
0.0, "30%", "Site",
)
require.Contains(t, body, "无限制")
require.Contains(t, body, "Unlimited")
}
func TestBuildQuotaAlertEmailBody_PercentageThresholdDisplay(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
700.0, 1000.0, 300.0,
"30%", // percentage-formatted threshold
"Site",
)
require.Contains(t, body, "30%")
require.NotContains(t, body, "%!")
}
func TestBuildQuotaAlertEmailBody_RemainingClampedAtZero(t *testing.T) {
// Even though caller is responsible for clamping, this test documents the
// display behavior with remaining=0.
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(
1, "n", "p", "dim",
1500.0, 1000.0, 0.0, // used > limit (over-quota)
"$100.00", "Site",
)
require.Contains(t, body, "$0.00")
}
// ---------- sanity checks on the CSS `%%` escape ----------
func TestBuildBalanceLowEmailBody_NoCSSFormatError(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
// CSS `linear-gradient(135deg, #f59e0b 0%, #d97706 100%)` should appear with
// literal percent signs (from the %% escape in the template).
require.True(t,
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
"CSS gradient percentages not rendered; got: %s", body)
}
func TestBuildQuotaAlertEmailBody_NoCSSFormatError(t *testing.T) {
s := &BalanceNotifyService{}
body := s.buildQuotaAlertEmailBody(1, "n", "p", "d", 0, 0, 0, "$0.00", "Site")
require.True(t,
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
"CSS gradient percentages not rendered; got: %s", body)
}

View File

@ -0,0 +1,479 @@
package service
import (
"context"
"fmt"
"html"
"log/slog"
"strconv"
"strings"
"time"
)
const (
emailSendTimeout = 30 * time.Second
// Threshold type values
thresholdTypeFixed = "fixed"
thresholdTypePercentage = "percentage"
// Quota dimension labels
quotaDimDaily = "daily"
quotaDimWeekly = "weekly"
quotaDimTotal = "total"
defaultSiteName = "Sub2API"
)
// quotaDimLabels maps dimension names to display labels.
var quotaDimLabels = map[string]string{
quotaDimDaily: "日限额 / Daily",
quotaDimWeekly: "周限额 / Weekly",
quotaDimTotal: "总限额 / Total",
}
// AccountQuotaReader provides read access to account quota data.
type AccountQuotaReader interface {
GetByID(ctx context.Context, id int64) (*Account, error)
}
// BalanceNotifyService handles balance and quota threshold notifications.
type BalanceNotifyService struct {
emailService *EmailService
settingRepo SettingRepository
accountRepo AccountQuotaReader
}
// NewBalanceNotifyService creates a new BalanceNotifyService.
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
return &BalanceNotifyService{
emailService: emailService,
settingRepo: settingRepo,
accountRepo: accountRepo,
}
}
// resolveBalanceThreshold returns the effective balance threshold.
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
return totalRecharged * threshold / 100
}
return threshold
}
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, user *User, oldBalance, cost float64) {
if !s.canNotifyBalance(user) {
return
}
effectiveThreshold, rechargeURL, ok := s.resolveUserEffectiveThreshold(ctx, user)
if !ok {
return
}
newBalance := oldBalance - cost
if !crossedDownward(oldBalance, newBalance, effectiveThreshold) {
return
}
s.dispatchBalanceLowEmail(ctx, user, newBalance, effectiveThreshold, rechargeURL)
}
// canNotifyBalance checks nil guards and user-level toggle.
func (s *BalanceNotifyService) canNotifyBalance(user *User) bool {
if user == nil || s.emailService == nil || s.settingRepo == nil {
return false
}
return user.BalanceNotifyEnabled
}
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
// Returns ok=false when notifications should be skipped.
func (s *BalanceNotifyService) resolveUserEffectiveThreshold(ctx context.Context, user *User) (effectiveThreshold float64, rechargeURL string, ok bool) {
globalEnabled, globalThreshold, rechargeURL := s.getBalanceNotifyConfig(ctx)
if !globalEnabled {
return 0, "", false
}
threshold := globalThreshold
if user.BalanceNotifyThreshold != nil {
threshold = *user.BalanceNotifyThreshold
}
if threshold <= 0 {
return 0, "", false
}
effectiveThreshold = resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
if effectiveThreshold <= 0 {
return 0, "", false
}
return effectiveThreshold, rechargeURL, true
}
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
func crossedDownward(oldV, newV, threshold float64) bool {
return oldV >= threshold && newV < threshold
}
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user *User, newBalance, threshold float64, rechargeURL string) {
siteName := s.getSiteName(ctx)
recipients := s.collectBalanceNotifyRecipients(user)
slog.Info("CheckBalanceAfterDeduction: sending notification",
"user_id", user.ID, "recipients", recipients, "new_balance", newBalance, "threshold", threshold)
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in balance notification", "recover", r)
}
}()
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
}()
}
// quotaDim describes one quota dimension for notification checking.
type quotaDim struct {
name string
enabled bool
threshold float64
thresholdType string // "fixed" (default) or "percentage"
currentUsed float64
limit float64
}
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
// The threshold represents how much quota REMAINS when the alert fires:
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
func (d quotaDim) resolvedThreshold() float64 {
if d.limit <= 0 {
return 0
}
if d.thresholdType == thresholdTypePercentage {
return d.limit * (1 - d.threshold/100)
}
return d.limit - d.threshold
}
// buildQuotaDims returns the three quota dimensions for notification checking.
func buildQuotaDims(account *Account) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()},
}
}
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
func buildQuotaDimsFromState(account *Account, state *AccountQuotaState) []quotaDim {
return []quotaDim{
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), state.DailyUsed, state.DailyLimit},
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), state.WeeklyUsed, state.WeeklyLimit},
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), state.TotalUsed, state.TotalLimit},
}
}
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64, quotaState *AccountQuotaState) {
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
return
}
if !s.isAccountQuotaNotifyEnabled(ctx) {
return
}
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
if len(adminEmails) == 0 {
return
}
siteName := s.getSiteName(ctx)
var dims []quotaDim
if quotaState != nil {
dims = buildQuotaDimsFromState(account, quotaState)
} else {
freshAccount := s.fetchFreshAccount(ctx, account)
dims = buildQuotaDims(freshAccount)
account = freshAccount // use fresh data for alert metadata
}
s.checkQuotaDimCrossings(account, dims, cost, adminEmails, siteName)
}
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
if s.accountRepo == nil {
return snapshot
}
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
if err != nil {
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
"account_id", snapshot.ID, "error", err)
return snapshot
}
return fresh
}
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
func (s *BalanceNotifyService) checkQuotaDimCrossings(account *Account, dims []quotaDim, cost float64, adminEmails []string, siteName string) {
for _, dim := range dims {
if !dim.enabled || dim.threshold <= 0 {
continue
}
effectiveThreshold := dim.resolvedThreshold()
if effectiveThreshold <= 0 {
continue
}
newUsed := dim.currentUsed
oldUsed := dim.currentUsed - cost
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
s.asyncSendQuotaAlert(adminEmails, account.ID, account.Name, account.Platform, dim, newUsed, effectiveThreshold, siteName)
}
}
}
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) {
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in quota notification", "recover", r)
}
}()
s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName)
}()
}
// getBalanceNotifyConfig reads global balance notification settings.
func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64, rechargeURL string) {
keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
return false, 0, ""
}
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
if f, err := strconv.ParseFloat(v, 64); err == nil {
threshold = f
}
}
rechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
return
}
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool {
val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled)
if err != nil {
return false
}
return val == "true"
}
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
// filtering out disabled and unverified entries.
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
return nil
}
entries := ParseNotifyEmails(raw)
if len(entries) == 0 {
return nil
}
return filterVerifiedEmails(entries)
}
// getSiteName reads site name from settings with fallback.
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
if err != nil || name == "" {
return defaultSiteName
}
return name
}
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
var recipients []string
seen := make(map[string]bool)
for _, entry := range entries {
if entry.Disabled || !entry.Verified {
continue
}
email := strings.TrimSpace(entry.Email)
if email == "" {
continue
}
lower := strings.ToLower(email)
if seen[lower] {
continue
}
seen[lower] = true
recipients = append(recipients, email)
}
return recipients
}
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
// Only emails with verified=true and disabled=false are included.
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
}
// sendEmails sends an email to all recipients with shared timeout and error logging.
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
if len(recipients) == 0 {
slog.Warn("sendEmails: no recipients", "subject", subject)
return
}
for _, to := range recipients {
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
attrs := append([]any{"to", to, "error", err}, logAttrs...)
slog.Error("failed to send notification", attrs...)
} else {
slog.Info("notification email sent successfully", "to", to, "subject", subject)
}
cancel()
}
}
// sendBalanceLowEmails sends balance low notification to all recipients.
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
displayName := userName
if displayName == "" {
displayName = userEmail
}
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
}
// sendQuotaAlertEmails sends quota alert notification to admin emails.
func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) {
dimLabel := quotaDimLabels[dim.name]
if dimLabel == "" {
dimLabel = dim.name
}
// Format the remaining-based threshold for display
thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold)
if dim.thresholdType == thresholdTypePercentage {
thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold)
}
remaining := dim.limit - used
if remaining < 0 {
remaining = 0
}
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
}
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func sanitizeEmailHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold.
// The recharge button is appended dynamically when rechargeURL is set.
const balanceLowEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; text-align: center; }
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333;">%s您的余额不足</p>
<p style="color: #666;">Dear %s, your balance is running low</p>
<div class="balance">$%.2f</div>
<div class="info">
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong></p>
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
<p>请及时充值以免服务中断</p>
<p>Please top up to avoid service interruption.</p>
</div>
%s
</div>
<div class="footer"><p>此邮件由系统自动发送请勿回复</p></div>
</div>
</body>
</html>`
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
const quotaAlertEmailTemplate = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
.header h1 { margin: 0; font-size: 24px; }
.content { padding: 40px 30px; }
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
.metric-label { color: #666; }
.metric-value { font-weight: bold; color: #333; }
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
</style>
</head>
<body>
<div class="container">
<div class="header"><h1>%s</h1></div>
<div class="content">
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
<div class="info">
<p>账号剩余额度已低于提醒阈值请及时关注</p>
<p>Account remaining quota has fallen below the alert threshold.</p>
</div>
</div>
<div class="footer"><p>此邮件由系统自动发送请勿回复</p></div>
</div>
</body>
</html>`
// buildBalanceLowEmailBody builds HTML email for balance low notification.
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName, rechargeURL string) string {
rechargeBlock := ""
if rechargeURL != "" {
rechargeBlock = fmt.Sprintf(`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`, html.EscapeString(rechargeURL))
}
return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold, rechargeBlock)
}
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string {
limitStr := fmt.Sprintf("$%.2f", limit)
if limit <= 0 {
limitStr = "无限制 / Unlimited"
}
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
}

View File

@ -0,0 +1,280 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- resolveBalanceThreshold ----------
func TestResolveBalanceThreshold_Fixed(t *testing.T) {
// Fixed type always returns the raw threshold regardless of totalRecharged.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 1000))
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypeFixed, 0))
require.Equal(t, 0.0, resolveBalanceThreshold(0, thresholdTypeFixed, 1000))
}
func TestResolveBalanceThreshold_Percentage(t *testing.T) {
// 10% of 1000 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(10, thresholdTypePercentage, 1000))
// 50% of 200 = 100
require.Equal(t, 100.0, resolveBalanceThreshold(50, thresholdTypePercentage, 200))
}
func TestResolveBalanceThreshold_PercentageZeroRecharged(t *testing.T) {
// When totalRecharged is 0, percentage falls through to raw threshold
// (treated as fixed). This is the defensive behavior.
require.Equal(t, 10.0, resolveBalanceThreshold(10, thresholdTypePercentage, 0))
}
func TestResolveBalanceThreshold_EmptyType(t *testing.T) {
// Empty type is treated as fixed (not percentage).
require.Equal(t, 10.0, resolveBalanceThreshold(10, "", 1000))
}
// ---------- quotaDim.resolvedThreshold ----------
func TestResolvedThreshold_FixedNormal(t *testing.T) {
// threshold=400 remaining, limit=1000 → usage trigger at 600
d := quotaDim{threshold: 400, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 600.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdExceedsLimit(t *testing.T) {
// threshold=1200, limit=1000 → returns negative, callers must skip
d := quotaDim{threshold: 1200, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, -200.0, d.resolvedThreshold())
}
func TestResolvedThreshold_FixedThresholdEqualsLimit(t *testing.T) {
// threshold=1000, limit=1000 → returns 0 (alert fires at 0 usage)
d := quotaDim{threshold: 1000, thresholdType: thresholdTypeFixed, limit: 1000}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_PercentageNormal(t *testing.T) {
// threshold=30%, limit=1000 → usage trigger at 700 (remaining drops to 30%)
d := quotaDim{threshold: 30, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 700.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageZeroPercent(t *testing.T) {
// threshold=0%, limit=1000 → fires when remaining drops to 0 (usage=1000)
d := quotaDim{threshold: 0, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 1000.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageHundredPercent(t *testing.T) {
// threshold=100%, limit=1000 → fires immediately (remaining drops to 100% i.e. nothing used yet)
d := quotaDim{threshold: 100, thresholdType: thresholdTypePercentage, limit: 1000}
require.InDelta(t, 0.0, d.resolvedThreshold(), 0.001)
}
func TestResolvedThreshold_PercentageOverHundred(t *testing.T) {
// threshold=150%, limit=1000 → returns negative (never triggers; callers skip)
d := quotaDim{threshold: 150, thresholdType: thresholdTypePercentage, limit: 1000}
require.Less(t, d.resolvedThreshold(), 0.0)
}
func TestResolvedThreshold_ZeroLimit(t *testing.T) {
// limit=0 → returns 0 to avoid division and false alerts on unlimited quotas
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: 0}
require.Equal(t, 0.0, d.resolvedThreshold())
}
func TestResolvedThreshold_NegativeLimit(t *testing.T) {
// Negative limit treated as 0
d := quotaDim{threshold: 100, thresholdType: thresholdTypeFixed, limit: -10}
require.Equal(t, 0.0, d.resolvedThreshold())
}
// ---------- sanitizeEmailHeader ----------
func TestSanitizeEmailHeader_CRLF(t *testing.T) {
require.Equal(t, "Subject injected", sanitizeEmailHeader("Subject\r\n injected"))
}
func TestSanitizeEmailHeader_OnlyCR(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\rbar"))
}
func TestSanitizeEmailHeader_OnlyLF(t *testing.T) {
require.Equal(t, "foobar", sanitizeEmailHeader("foo\nbar"))
}
func TestSanitizeEmailHeader_Clean(t *testing.T) {
require.Equal(t, "Sub2API", sanitizeEmailHeader("Sub2API"))
}
func TestSanitizeEmailHeader_Empty(t *testing.T) {
require.Equal(t, "", sanitizeEmailHeader(""))
}
func TestSanitizeEmailHeader_MultipleNewlines(t *testing.T) {
require.Equal(t, "abc", sanitizeEmailHeader("a\r\nb\r\nc"))
}
// ---------- buildQuotaDims ----------
func TestBuildQuotaDims_AllDimensionsReturned(t *testing.T) {
// Use an account with quota notify config across all 3 dimensions.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_notify_daily_threshold_type": thresholdTypeFixed,
"quota_notify_weekly_enabled": true,
"quota_notify_weekly_threshold": 20.0,
"quota_notify_weekly_threshold_type": thresholdTypePercentage,
"quota_notify_total_enabled": false,
"quota_daily_limit": 500.0,
"quota_weekly_limit": 2000.0,
"quota_limit": 10000.0,
"quota_daily_used": 50.0,
"quota_weekly_used": 300.0,
"quota_used": 1000.0,
},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
// Daily
require.Equal(t, quotaDimDaily, dims[0].name)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
require.Equal(t, thresholdTypeFixed, dims[0].thresholdType)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 50.0, dims[0].currentUsed)
// Weekly
require.Equal(t, quotaDimWeekly, dims[1].name)
require.True(t, dims[1].enabled)
require.Equal(t, 20.0, dims[1].threshold)
require.Equal(t, thresholdTypePercentage, dims[1].thresholdType)
require.Equal(t, 2000.0, dims[1].limit)
// Total
require.Equal(t, quotaDimTotal, dims[2].name)
require.False(t, dims[2].enabled)
require.Equal(t, 10000.0, dims[2].limit)
require.Equal(t, 1000.0, dims[2].currentUsed)
}
func TestBuildQuotaDims_EmptyExtra(t *testing.T) {
// Missing fields default to zero/disabled.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{},
}
dims := buildQuotaDims(a)
require.Len(t, dims, 3)
for _, d := range dims {
require.False(t, d.enabled)
require.Equal(t, 0.0, d.threshold)
require.Equal(t, 0.0, d.limit)
}
}
// ---------- buildQuotaDimsFromState ----------
func TestBuildQuotaDimsFromState_UsesStateValues(t *testing.T) {
// Usage values should come from the state, not the account.
a := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"quota_notify_daily_enabled": true,
"quota_notify_daily_threshold": 100.0,
"quota_daily_used": 999.0, // should be ignored
"quota_daily_limit": 999.0, // should be ignored
},
}
state := &AccountQuotaState{
DailyUsed: 77.0,
DailyLimit: 500.0,
WeeklyUsed: 88.0,
WeeklyLimit: 2000.0,
TotalUsed: 99.0,
TotalLimit: 10000.0,
}
dims := buildQuotaDimsFromState(a, state)
require.Len(t, dims, 3)
// Settings from account (enabled, threshold, thresholdType)
require.True(t, dims[0].enabled)
require.Equal(t, 100.0, dims[0].threshold)
// Usage from state
require.Equal(t, 77.0, dims[0].currentUsed)
require.Equal(t, 500.0, dims[0].limit)
require.Equal(t, 88.0, dims[1].currentUsed)
require.Equal(t, 2000.0, dims[1].limit)
require.Equal(t, 99.0, dims[2].currentUsed)
require.Equal(t, 10000.0, dims[2].limit)
}
// ---------- collectBalanceNotifyRecipients ----------
func TestCollectBalanceNotifyRecipients_Empty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{BalanceNotifyExtraEmails: nil}
require.Empty(t, s.collectBalanceNotifyRecipients(u))
}
func TestCollectBalanceNotifyRecipients_FiltersDisabledAndUnverified(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "a@example.com", Verified: true, Disabled: false},
{Email: "b@example.com", Verified: true, Disabled: true}, // disabled
{Email: "c@example.com", Verified: false, Disabled: false}, // unverified
{Email: "d@example.com", Verified: true, Disabled: false},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"a@example.com", "d@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_DeduplicatesCaseInsensitive(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: "User@Example.com", Verified: true},
{Email: "user@example.com", Verified: true},
{Email: "USER@EXAMPLE.COM", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Len(t, got, 1)
// The original casing of the first entry is preserved.
require.Equal(t, "User@Example.com", got[0])
}
func TestCollectBalanceNotifyRecipients_SkipsEmpty(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " ", Verified: true},
{Email: "", Verified: true},
{Email: "valid@example.com", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"valid@example.com"}, got)
}
func TestCollectBalanceNotifyRecipients_TrimsWhitespace(t *testing.T) {
s := &BalanceNotifyService{}
u := &User{
BalanceNotifyExtraEmails: []NotifyEmailEntry{
{Email: " trimmed@example.com ", Verified: true},
},
}
got := s.collectBalanceNotifyRecipients(u)
require.Equal(t, []string{"trimmed@example.com"}, got)
}

View File

@ -363,7 +363,6 @@ func TestCalculateImageCost(t *testing.T) {
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
} }
func TestIsModelSupported(t *testing.T) { func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService() svc := newTestBillingService()
@ -719,3 +718,123 @@ func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12) require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12) require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
} }
// ---------------------------------------------------------------------------
// GetModelPricingWithChannel
// ---------------------------------------------------------------------------
func TestGetModelPricingWithChannel_NilChannelPricing_ReturnsOriginal(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", nil)
require.NoError(t, err)
require.NotNil(t, pricing)
// Should be identical to GetModelPricing
original, err := svc.GetModelPricing("claude-sonnet-4")
require.NoError(t, err)
require.InDelta(t, original.InputPricePerToken, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, original.OutputPricePerToken, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, original.CacheCreationPricePerToken, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, original.CacheReadPricePerToken, pricing.CacheReadPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideInputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(99e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// InputPrice overridden (both normal and priority)
require.InDelta(t, 99e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 99e-6, pricing.InputPricePerTokenPriority, 1e-12)
// OutputPrice unchanged (claude-sonnet-4 fallback = 15e-6)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideOutputPriceOnly(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
OutputPrice: testPtrFloat64(88e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// OutputPrice overridden
require.InDelta(t, 88e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 88e-6, pricing.OutputPricePerTokenPriority, 1e-12)
// InputPrice unchanged (claude-sonnet-4 fallback = 3e-6)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_OverrideAllFields(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(20e-6),
CacheWritePrice: testPtrFloat64(5e-6),
CacheReadPrice: testPtrFloat64(1e-6),
ImageOutputPrice: testPtrFloat64(50e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
require.InDelta(t, 10e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 10e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 20e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 1e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.InDelta(t, 50e-6, pricing.ImageOutputPricePerToken, 1e-12)
}
func TestGetModelPricingWithChannel_CacheWritePriceAffects5mAnd1h(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheWritePrice: testPtrFloat64(7e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheWritePrice should set all three: CacheCreationPricePerToken, 5m, and 1h
require.InDelta(t, 7e-6, pricing.CacheCreationPricePerToken, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 7e-6, pricing.CacheCreation1hPrice, 1e-12)
}
func TestGetModelPricingWithChannel_CacheReadPriceAffectsPriority(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
CacheReadPrice: testPtrFloat64(2e-6),
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)
// CacheReadPrice should set both normal and priority
require.InDelta(t, 2e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 2e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
}
func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) {
svc := newTestBillingService()
chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(1e-6),
}
pricing, err := svc.GetModelPricingWithChannel("totally-unknown-model", chPricing)
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}

View File

@ -0,0 +1,258 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// CalculateCostUnified
// ---------------------------------------------------------------------------
func TestCalculateCostUnified_NilResolver_FallsBackToOldPath(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: nil, // no resolver
}
cost, err := svc.CalculateCostUnified(input)
require.NoError(t, err)
// Should match the old-path result exactly
expected, err := svc.calculateCostInternal("claude-sonnet-4", tokens, 1.0, "", nil)
require.NoError(t, err)
require.InDelta(t, expected.TotalCost, cost.TotalCost, 1e-10)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
// BillingMode is NOT set by old path through CalculateCostUnified (resolver == nil)
require.Empty(t, cost.BillingMode)
}
func TestCalculateCostUnified_TokenMode(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.5,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// Verify token billing: Input: 1000*3e-6=0.003, Output: 500*15e-6=0.0075
expectedTotal := 1000*3e-6 + 500*15e-6
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
require.InDelta(t, expectedTotal*1.5, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeToken), cost.BillingMode)
}
func TestCalculateCostUnified_PerRequestMode(t *testing.T) {
// Set up a ChannelService with a per-request pricing channel
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 1, model: "claude-sonnet-4"}: {
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.05),
},
},
channelByGroupID: map[int64]*Channel{
1: {ID: 1, Status: StatusActive},
},
groupPlatform: map[int64]string{1: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := newTestBillingService()
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(1)
input := CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
GroupID: &groupID,
Tokens: UsageTokens{InputTokens: 100, OutputTokens: 50},
RequestCount: 3,
RateMultiplier: 2.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 3 requests * $0.05 = $0.15
require.InDelta(t, 0.15, cost.TotalCost, 1e-10)
// ActualCost = 0.15 * 2.0 = 0.30
require.InDelta(t, 0.30, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
func TestCalculateCostUnified_ImageMode(t *testing.T) {
cs := newTestChannelServiceWithCache(t, &channelCache{
pricingByGroupModel: map[channelModelKey]*ChannelModelPricing{
{groupID: 2, model: "gemini-image"}: {
BillingMode: BillingModeImage,
PerRequestPrice: testPtrFloat64(0.10),
},
},
channelByGroupID: map[int64]*Channel{
2: {ID: 2, Status: StatusActive},
},
groupPlatform: map[int64]string{2: ""},
wildcardByGroupPlatform: map[channelGroupPlatformKey][]*wildcardPricingEntry{},
mappingByGroupModel: map[channelModelKey]string{},
wildcardMappingByGP: map[channelGroupPlatformKey][]*wildcardMappingEntry{},
byID: map[int64]*Channel{},
})
bs := &BillingService{
cfg: &config.Config{},
fallbackPrices: map[string]*ModelPricing{},
}
resolver := NewModelPricingResolver(cs, bs)
groupID := int64(2)
input := CostInput{
Ctx: context.Background(),
Model: "gemini-image",
GroupID: &groupID,
Tokens: UsageTokens{},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
}
cost, err := bs.CalculateCostUnified(input)
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.10 = $0.20
require.InDelta(t, 0.20, cost.TotalCost, 1e-10)
require.InDelta(t, 0.20, cost.ActualCost, 1e-10)
require.Equal(t, string(BillingModeImage), cost.BillingMode)
}
func TestCalculateCostUnified_RateMultiplierZeroDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
costZero, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 0, // should default to 1.0
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCostUnified_NegativeRateMultiplierDefaultsToOne(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: -5.0,
Resolver: resolver,
})
require.NoError(t, err)
costOne, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: tokens,
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestCalculateCostUnified_BillingModeFieldFilled(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RateMultiplier: 1.0,
Resolver: resolver,
})
require.NoError(t, err)
require.Equal(t, "token", cost.BillingMode)
}
func TestCalculateCostUnified_UsesPreResolvedPricing(t *testing.T) {
bs := newTestBillingService()
resolver := NewModelPricingResolver(nil, bs)
// Pre-resolve with per_request mode to verify it's used instead of re-resolving
preResolved := &ResolvedPricing{
Mode: BillingModePerRequest,
DefaultPerRequestPrice: 0.07,
}
cost, err := bs.CalculateCostUnified(CostInput{
Ctx: context.Background(),
Model: "claude-sonnet-4",
Tokens: UsageTokens{InputTokens: 100},
RequestCount: 2,
RateMultiplier: 1.0,
Resolver: resolver,
Resolved: preResolved,
})
require.NoError(t, err)
require.NotNil(t, cost)
// 2 * $0.07 = $0.14
require.InDelta(t, 0.14, cost.TotalCost, 1e-10)
require.Equal(t, string(BillingModePerRequest), cost.BillingMode)
}
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestChannelServiceWithCache creates a ChannelService with a pre-populated
// cache snapshot, bypassing the repository layer entirely.
func newTestChannelServiceWithCache(t *testing.T, cache *channelCache) *ChannelService {
t.Helper()
cs := &ChannelService{}
cache.loadedAt = time.Now()
cs.cache.Store(cache)
return cs
}

View File

@ -37,8 +37,10 @@ type Channel struct {
Name string Name string
Description string Description string
Status string Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped" BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述JSON 数组),用于支付页面展示
FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
@ -48,6 +50,25 @@ type Channel struct {
ModelPricing []ChannelModelPricing ModelPricing []ChannelModelPricing
// 渠道级模型映射按平台分组platform → {src→dst} // 渠道级模型映射按平台分组platform → {src→dst}
ModelMapping map[string]map[string]string ModelMapping map[string]map[string]string
// 账号统计定价
ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计
AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准)
}
// AccountStatsPricingRule 账号统计定价规则
// 每条规则包含匹配条件(分组/账号)和独立的模型定价。
// 多条规则按 SortOrder 排序,先命中为准。
type AccountStatsPricingRule struct {
ID int64
ChannelID int64
Name string
GroupIDs []int64
AccountIDs []int64
SortOrder int
Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构)
CreatedAt time.Time
UpdatedAt time.Time
} }
// ChannelModelPricing 渠道模型定价条目 // ChannelModelPricing 渠道模型定价条目
@ -176,9 +197,58 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner cp.ModelMapping[platform] = inner
} }
} }
if c.FeaturesConfig != nil {
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
}
if c.AccountStatsPricingRules != nil {
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
for i, rule := range c.AccountStatsPricingRules {
cp.AccountStatsPricingRules[i] = rule
if rule.GroupIDs != nil {
cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs))
copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs)
}
if rule.AccountIDs != nil {
cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs))
copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs)
}
if rule.Pricing != nil {
cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing))
for j := range rule.Pricing {
cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone()
}
}
}
}
return &cp return &cp
} }
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))
for k, v := range src {
if inner, ok := v.(map[string]any); ok {
dst[k] = deepCopyFeaturesConfig(inner)
} else {
dst[k] = v
}
}
return dst
}
// ValidateIntervals 校验区间列表的合法性。 // ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens // 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义); // 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);

View File

@ -81,9 +81,9 @@ type wildcardMappingEntry struct {
type channelCache struct { type channelCache struct {
// 热路径查找 // 热路径查找
pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价
wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序 wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(按配置顺序,先匹配先使用
mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标
wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(前缀长度降序 wildcardMappingByGP map[channelGroupPlatformKey][]*wildcardMappingEntry // (groupID, platform) → 通配符映射(按配置顺序,先匹配先使用
channelByGroupID map[int64]*Channel // groupID → 渠道 channelByGroupID map[int64]*Channel // groupID → 渠道
groupPlatform map[int64]string // groupID → platform groupPlatform map[int64]string // groupID → platform
@ -315,6 +315,7 @@ func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *
expandMappingToCache(cache, ch, gid, platform) expandMappingToCache(cache, ch, gid, platform)
} }
} }
return cache return cache
} }
@ -415,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64)
return ch.Clone(), nil return ch.Clone(), nil
} }
// GetGroupPlatform 获取分组的平台标识(从缓存)
func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string {
cache, err := s.loadCache(ctx)
if err != nil {
return ""
}
return cache.groupPlatform[groupID]
}
// channelLookup 热路径公共查找结果 // channelLookup 热路径公共查找结果
type channelLookup struct { type channelLookup struct {
cache *channelCache cache *channelCache
@ -556,15 +566,21 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 // validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。 // Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validatePricingEntries(pricing); err != nil {
return err
}
return validateNoConflictingMappings(mapping)
}
// validatePricingEntries 校验定价条目(冲突检测 + 区间校验 + 计费模式校验),
// 同时用于主渠道定价和 account_stats_pricing_rules 的内部定价。
func validatePricingEntries(pricing []ChannelModelPricing) error {
if err := validateNoConflictingModels(pricing); err != nil { if err := validateNoConflictingModels(pricing); err != nil {
return err return err
} }
if err := validatePricingIntervals(pricing); err != nil { if err := validatePricingIntervals(pricing); err != nil {
return err return err
} }
if err := validateNoConflictingMappings(mapping); err != nil {
return err
}
return validatePricingBillingMode(pricing) return validatePricingBillingMode(pricing)
} }
@ -655,14 +671,18 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
} }
channel := &Channel{ channel := &Channel{
Name: input.Name, Name: input.Name,
Description: input.Description, Description: input.Description,
Status: StatusActive, Status: StatusActive,
BillingModelSource: input.BillingModelSource, BillingModelSource: input.BillingModelSource,
RestrictModels: input.RestrictModels, RestrictModels: input.RestrictModels,
GroupIDs: input.GroupIDs, GroupIDs: input.GroupIDs,
ModelPricing: input.ModelPricing, ModelPricing: input.ModelPricing,
ModelMapping: input.ModelMapping, ModelMapping: input.ModelMapping,
Features: input.Features,
FeaturesConfig: input.FeaturesConfig,
ApplyPricingToAccountStats: input.ApplyPricingToAccountStats,
AccountStatsPricingRules: input.AccountStatsPricingRules,
} }
if channel.BillingModelSource == "" { if channel.BillingModelSource == "" {
channel.BillingModelSource = BillingModelSourceChannelMapped channel.BillingModelSource = BillingModelSourceChannelMapped
@ -671,6 +691,11 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
if err := s.repo.Create(ctx, channel); err != nil { if err := s.repo.Create(ctx, channel); err != nil {
return nil, fmt.Errorf("create channel: %w", err) return nil, fmt.Errorf("create channel: %w", err)
@ -699,6 +724,11 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err return nil, err
} }
for i, rule := range channel.AccountStatsPricingRules {
if err := validatePricingEntries(rule.Pricing); err != nil {
return nil, fmt.Errorf("account stats pricing rule #%d: %w", i+1, err)
}
}
oldGroupIDs := s.getOldGroupIDs(ctx, id) oldGroupIDs := s.getOldGroupIDs(ctx, id)
@ -733,6 +763,9 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.RestrictModels != nil { if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels channel.RestrictModels = *input.RestrictModels
} }
if input.Features != nil {
channel.Features = *input.Features
}
if input.GroupIDs != nil { if input.GroupIDs != nil {
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil { if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
return err return err
@ -748,6 +781,15 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel,
if input.BillingModelSource != "" { if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource channel.BillingModelSource = input.BillingModelSource
} }
if input.FeaturesConfig != nil {
channel.FeaturesConfig = input.FeaturesConfig
}
if input.ApplyPricingToAccountStats != nil {
channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats
}
if input.AccountStatsPricingRules != nil {
channel.AccountStatsPricingRules = *input.AccountStatsPricingRules
}
return nil return nil
} }
@ -913,23 +955,31 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro
// CreateChannelInput 创建渠道输入 // CreateChannelInput 创建渠道输入
type CreateChannelInput struct { type CreateChannelInput struct {
Name string Name string
Description string Description string
GroupIDs []int64 GroupIDs []int64
ModelPricing []ChannelModelPricing ModelPricing []ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst} ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string BillingModelSource string
RestrictModels bool RestrictModels bool
Features string
FeaturesConfig map[string]any
ApplyPricingToAccountStats bool
AccountStatsPricingRules []AccountStatsPricingRule
} }
// UpdateChannelInput 更新渠道输入 // UpdateChannelInput 更新渠道输入
type UpdateChannelInput struct { type UpdateChannelInput struct {
Name string Name string
Description *string Description *string
Status string Status string
GroupIDs *[]int64 GroupIDs *[]int64
ModelPricing *[]ChannelModelPricing ModelPricing *[]ChannelModelPricing
ModelMapping map[string]map[string]string // platform → {src→dst} ModelMapping map[string]map[string]string // platform → {src→dst}
BillingModelSource string BillingModelSource string
RestrictModels *bool RestrictModels *bool
Features *string
FeaturesConfig map[string]any
ApplyPricingToAccountStats *bool
AccountStatsPricingRules *[]AccountStatsPricingRule
} }

View File

@ -0,0 +1,62 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestChannel_IsWebSearchEmulationEnabled_Enabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.True(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_DifferentPlatform(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": true},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("openai"))
}
func TestChannel_IsWebSearchEmulationEnabled_Disabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": false},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilFeaturesConfig(t *testing.T) {
c := &Channel{FeaturesConfig: nil}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_NilChannel(t *testing.T) {
var c *Channel
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_WrongStructure(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: true, // not a map
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}
func TestChannel_IsWebSearchEmulationEnabled_PlatformValueNotBool(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{"anthropic": "yes"},
},
}
require.False(t, c.IsWebSearchEmulationEnabled("anthropic"))
}

View File

@ -343,8 +343,9 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor
}() }()
} }
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts.
// Returns a map of accountID -> current concurrency count // Uses a detached context with timeout to prevent HTTP request cancellation from
// causing the entire batch to fail (which would show all concurrency as 0).
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 { if len(accountIDs) == 0 {
return map[int64]int{}, nil return map[int64]int{}, nil
@ -356,5 +357,11 @@ func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, acc
} }
return result, nil return result, nil
} }
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
// Use a detached context so that a cancelled HTTP request doesn't cause
// the Redis pipeline to fail and return all-zero concurrency counts.
redisCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return s.cache.GetAccountConcurrencyBatch(redisCtx, accountIDs)
} }

View File

@ -249,6 +249,18 @@ const (
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false // SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false
SettingKeyEnableCCHSigning = "enable_cch_signing" SettingKeyEnableCCHSigning = "enable_cch_signing"
// Balance Low Notification
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值USD
SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@ -7,8 +7,9 @@ import (
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log/slog"
"math/big" "math/big"
"net"
"net/smtp" "net/smtp"
"net/url" "net/url"
"strconv" "strconv"
@ -34,6 +35,11 @@ type EmailCache interface {
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteVerificationCode(ctx context.Context, email string) error DeleteVerificationCode(ctx context.Context, email string) error
// Notify email verification code methods
GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error)
SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
DeleteNotifyVerifyCode(ctx context.Context, email string) error
// Password reset token methods // Password reset token methods
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
@ -43,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently) // Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
// Notify code rate limiting per user
IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error)
GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error)
} }
// VerificationCodeData represents verification code data // VerificationCodeData represents verification code data
@ -50,6 +60,7 @@ type VerificationCodeData struct {
Code string Code string
Attempts int Attempts int
CreatedAt time.Time CreatedAt time.Time
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
} }
// PasswordResetTokenData represents password reset token data // PasswordResetTokenData represents password reset token data
@ -146,11 +157,18 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
return s.SendEmailWithConfig(config, to, subject, body) return s.SendEmailWithConfig(config, to, subject, body)
} }
const smtpDialTimeout = 10 * time.Second
const smtpIOTimeout = 20 * time.Second
// SendEmailWithConfig 使用指定配置发送邮件 // SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error { func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From // Sanitize all SMTP header fields to prevent header injection (CR/LF removal).
to = sanitizeEmailHeader(to)
subject = sanitizeEmailHeader(subject)
from := sanitizeEmailHeader(config.From)
if config.FromName != "" { if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From) from = fmt.Sprintf("%s <%s>", sanitizeEmailHeader(config.FromName), sanitizeEmailHeader(config.From))
} }
msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s", msg := fmt.Sprintf("From: %s\r\nTo: %s\r\nSubject: %s\r\nMIME-Version: 1.0\r\nContent-Type: text/html; charset=UTF-8\r\n\r\n%s",
@ -163,7 +181,54 @@ func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body
return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host) return s.sendMailTLS(addr, auth, config.From, to, []byte(msg), config.Host)
} }
return smtp.SendMail(addr, auth, config.From, []string{to}, []byte(msg)) return s.sendMailPlain(addr, auth, config.From, to, []byte(msg), config.Host)
}
// sendMailPlain sends mail without TLS using a dialer with timeout.
func (s *EmailService) sendMailPlain(addr string, auth smtp.Auth, from, to string, msg []byte, host string) error {
dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("smtp dial: %w", err)
}
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host)
if err != nil {
return fmt.Errorf("new smtp client: %w", err)
}
defer func() { _ = client.Close() }()
// Opportunistic STARTTLS: upgrade to encrypted connection if the server supports it.
// This mirrors the behavior of smtp.SendMail which we replaced for timeout support.
if ok, _ := client.Extension("STARTTLS"); ok {
if err = client.StartTLS(&tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}); err != nil {
return fmt.Errorf("starttls: %w", err)
}
}
if err = client.Auth(auth); err != nil {
return fmt.Errorf("smtp auth: %w", err)
}
if err = client.Mail(from); err != nil {
return fmt.Errorf("smtp mail: %w", err)
}
if err = client.Rcpt(to); err != nil {
return fmt.Errorf("smtp rcpt: %w", err)
}
w, err := client.Data()
if err != nil {
return fmt.Errorf("smtp data: %w", err)
}
if _, err = w.Write(msg); err != nil {
return fmt.Errorf("write msg: %w", err)
}
if err = w.Close(); err != nil {
return fmt.Errorf("close writer: %w", err)
}
_ = client.Quit()
return nil
} }
// sendMailTLS 使用TLS发送邮件 // sendMailTLS 使用TLS发送邮件
@ -174,10 +239,12 @@ func (s *EmailService) sendMailTLS(addr string, auth smtp.Auth, from, to string,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
conn, err := tls.Dial("tcp", addr, tlsConfig) dialer := &net.Dialer{Timeout: smtpDialTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil { if err != nil {
return fmt.Errorf("tls dial: %w", err) return fmt.Errorf("tls dial: %w", err)
} }
_ = conn.SetDeadline(time.Now().Add(smtpIOTimeout))
defer func() { _ = conn.Close() }() defer func() { _ = conn.Close() }()
client, err := smtp.NewClient(conn, host) client, err := smtp.NewClient(conn, host)
@ -254,6 +321,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
Code: code, Code: code,
Attempts: 0, Attempts: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(verifyCodeTTL),
} }
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err) return fmt.Errorf("save verify code: %w", err)
@ -286,8 +354,12 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证码不匹配 (constant-time comparison to prevent timing attacks) // 验证码不匹配 (constant-time comparison to prevent timing attacks)
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++ data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { remaining := time.Until(data.ExpiresAt)
log.Printf("[Email] Failed to update verification attempt count: %v", err) if remaining <= 0 {
return ErrInvalidVerifyCode
}
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
slog.Error("failed to update verification attempt count", "email", email, "error", err)
} }
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
@ -297,7 +369,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码 // 验证成功,删除验证码
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err) slog.Error("failed to delete verification code after success", "email", email, "error", err)
} }
return nil return nil
} }
@ -447,7 +519,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing // Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email) slog.Info("password reset email skipped due to cooldown", "email", email)
return nil // Silent success to prevent revealing cooldown to attackers return nil // Silent success to prevent revealing cooldown to attackers
} }
@ -458,7 +530,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration) // Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil { if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err) slog.Error("failed to set password reset cooldown", "email", email, "error", err)
} }
return nil return nil
@ -488,7 +560,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use) // Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil { if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err) slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
} }
return nil return nil
} }

View File

@ -43,6 +43,7 @@ func newGatewayRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo
nil, nil,
nil, nil,
nil, nil,
nil,
) )
} }

View File

@ -75,6 +75,9 @@ type ParsedRequest struct {
MaxTokens int // max_tokens 值(用于探测请求拦截) MaxTokens int // max_tokens 值(用于探测请求拦截)
SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变) SessionContext *SessionContext // 可选请求上下文区分因子nil 时行为不变)
// GroupID 请求所属分组 ID来自 API Key
GroupID *int64
// OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁) // OnUpstreamAccepted 上游接受请求后立即调用(用于提前释放串行锁)
// 流式请求在收到 2xx 响应头后调用,避免持锁等流完成 // 流式请求在收到 2xx 响应头后调用,避免持锁等流完成
OnUpstreamAccepted func() OnUpstreamAccepted func()

View File

@ -503,7 +503,6 @@ type ForwardResult struct {
// 图片生成计费字段(图片生成模型使用) // 图片生成计费字段(图片生成模型使用)
ImageCount int // 生成的图片数量 ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K" ImageSize string // 图片尺寸 "1K", "2K", "4K"
} }
// UpstreamFailoverError indicates an upstream error that should trigger account failover. // UpstreamFailoverError indicates an upstream error that should trigger account failover.
@ -570,6 +569,7 @@ type GatewayService struct {
resolver *ModelPricingResolver resolver *ModelPricingResolver
debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set debugGatewayBodyFile atomic.Pointer[os.File] // non-nil when SUB2API_DEBUG_GATEWAY_BODY is set
tlsFPProfileService *TLSFingerprintProfileService tlsFPProfileService *TLSFingerprintProfileService
balanceNotifyService *BalanceNotifyService
} }
// NewGatewayService creates a new GatewayService // NewGatewayService creates a new GatewayService
@ -599,6 +599,7 @@ func NewGatewayService(
tlsFPProfileService *TLSFingerprintProfileService, tlsFPProfileService *TLSFingerprintProfileService,
channelService *ChannelService, channelService *ChannelService,
resolver *ModelPricingResolver, resolver *ModelPricingResolver,
balanceNotifyService *BalanceNotifyService,
) *GatewayService { ) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
modelsListTTL := resolveModelsListCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg)
@ -633,6 +634,7 @@ func NewGatewayService(
tlsFPProfileService: tlsFPProfileService, tlsFPProfileService: tlsFPProfileService,
channelService: channelService, channelService: channelService,
resolver: resolver, resolver: resolver,
balanceNotifyService: balanceNotifyService,
} }
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
userGroupRateRepo, userGroupRateRepo,
@ -1329,6 +1331,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withWindowCostPrefetch(ctx, accounts)
ctx = s.withRPMPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts)
// 提前构建 accountByID供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
isExcluded := func(accountID int64) bool { isExcluded := func(accountID int64) bool {
if excludedIDs == nil { if excludedIDs == nil {
return false return false
@ -1337,12 +1344,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded return excluded
} }
// 提前构建 accountByID供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台) // 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64 var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
@ -1598,7 +1599,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
account, ok := accountByID[accountID] account, ok := accountByID[accountID]
if ok { if ok {
// 检查账户是否需要清理粘性会话绑定 // 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account, requestedModel) clearSticky := shouldClearStickySession(account, requestedModel)
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
@ -1614,7 +1614,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired { if err == nil && result.Acquired {
// 会话数量限制检查 // 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2 result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else { } else {
@ -1628,10 +1627,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting { if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额) // 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) { if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2 // 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else { } else {
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
AccountID: accountID, AccountID: accountID,
@ -2740,7 +2737,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if s.debugModelRoutingEnabled() { if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
} }
@ -3119,7 +3116,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky { if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
} }
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil return account, nil
} }
@ -3435,6 +3432,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
_, ok := ResolveBedrockModelID(account, requestedModel) _, ok := ResolveBedrockModelID(account, requestedModel)
return ok return ok
} }
// OpenAI 透传模式:仅替换认证,允许所有模型
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
return true
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID // OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
requestedModel = claude.NormalizeModelID(requestedModel) requestedModel = claude.NormalizeModelID(requestedModel)
@ -3934,6 +3935,11 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("parse request: empty request") return nil, fmt.Errorf("parse request: empty request")
} }
// Web Search 模拟:纯 web_search 请求时,直接调用搜索 API 构造响应
if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
return s.handleWebSearchEmulation(ctx, c, account, parsed)
}
if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
passthroughBody := parsed.Body passthroughBody := parsed.Body
passthroughModel := parsed.Model passthroughModel := parsed.Model
@ -7279,6 +7285,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
// RecordUsageInput 记录使用量的输入参数 // RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct { type RecordUsageInput struct {
Result *ForwardResult Result *ForwardResult
ParsedRequest *ParsedRequest
APIKey *APIKey APIKey *APIKey
User *User User *User
Account *Account Account *Account
@ -7333,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
} }
// postUsageBilling 统一处理使用量记录后的扣费逻辑: // postUsageBilling is the legacy fallback billing path used when the unified
// - 订阅/余额扣费 // billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
// - API Key 配额更新 // for atomic billing. This path only runs in tests or degraded mode.
// - API Key 限速用量更新
// - 账号配额用量更新账号口径TotalCost × 账号计费倍率)
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
billingCtx, cancel := detachedBillingContext(ctx) billingCtx, cancel := detachedBillingContext(ctx)
defer cancel() defer cancel()
cost := p.Cost cost := p.Cost
// 1. 订阅 / 余额扣费
if p.IsSubscriptionBill { if p.IsSubscriptionBill {
if cost.TotalCost > 0 { if cost.TotalCost > 0 {
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil { if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
} }
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
} }
} else { } else {
if cost.ActualCost > 0 { if cost.ActualCost > 0 {
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
} }
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
} }
} }
// 2. API Key 配额
if p.shouldDeductAPIKeyQuota() { if p.shouldDeductAPIKeyQuota() {
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 3. API Key 限速用量
if p.shouldUpdateRateLimits() { if p.shouldUpdateRateLimits() {
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
} }
} }
// 4. 账号配额用量账号口径TotalCost × 账号计费倍率)
if p.shouldUpdateAccountQuota() { if p.shouldUpdateAccountQuota() {
accountCost := cost.TotalCost * p.AccountRateMultiplier accountCost := cost.TotalCost * p.AccountRateMultiplier
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
@ -7383,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
} }
} }
finalizePostUsageBilling(p, deps) // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
// cache updates. The legacy path does DB writes directly; the finalize path
// does cache queue + notifications. Notifications are dispatched separately
// by the caller after recording the usage log.
} }
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
@ -7499,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
} }
} }
finalizePostUsageBilling(p, deps) finalizePostUsageBilling(p, deps, result)
return true, nil return true, nil
} }
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) { func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
if p == nil || p.Cost == nil || deps == nil { if p == nil || p.Cost == nil || deps == nil {
return return
} }
@ -7521,6 +7523,83 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
} }
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
// Notification checks run async — all parameters are already captured,
// no dependency on the request context or upstream connection.
go notifyBalanceLow(p, deps, result)
go notifyAccountQuota(p, deps, result)
}
// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyBalanceLow", "recover", r)
}
}()
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
slog.Debug("notifyBalanceLow: skipped",
"is_subscription", p.IsSubscriptionBill,
"actual_cost", p.Cost.ActualCost,
"user_nil", p.User == nil,
"service_nil", deps.balanceNotifyService == nil,
)
return
}
oldBalance := resolveOldBalance(p, result)
slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
"user_id", p.User.ID,
"old_balance", oldBalance,
"cost", p.Cost.ActualCost,
"notify_enabled", p.User.BalanceNotifyEnabled,
"threshold", p.User.BalanceNotifyThreshold,
"result_has_new_balance", result != nil && result.NewBalance != nil,
)
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
}
// resolveOldBalance returns the pre-deduction balance.
// Prefers the DB transaction result (newBalance + cost) over snapshot.
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
if result != nil && result.NewBalance != nil {
return *result.NewBalance + p.Cost.ActualCost
}
// Legacy fallback: snapshot balance from request context
return p.User.Balance
}
// notifyAccountQuota sends account quota threshold notification after increment.
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
// to avoid a separate DB read that may see stale or concurrently-modified data.
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
defer func() {
if r := recover(); r != nil {
slog.Error("panic in notifyAccountQuota", "recover", r)
}
}()
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
slog.Debug("notifyAccountQuota: skipped",
"total_cost", p.Cost.TotalCost,
"account_nil", p.Account == nil,
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
"service_nil", deps.balanceNotifyService == nil,
)
return
}
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
var quotaState *AccountQuotaState
if result != nil {
quotaState = result.QuotaState
}
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
"account_id", p.Account.ID,
"account_cost", accountCost,
"has_quota_state", quotaState != nil,
)
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
} }
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
@ -7543,20 +7622,22 @@ func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Cont
// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) // billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供)
type billingDeps struct { type billingDeps struct {
accountRepo AccountRepository accountRepo AccountRepository
userRepo UserRepository userRepo UserRepository
userSubRepo UserSubscriptionRepository userSubRepo UserSubscriptionRepository
billingCacheService *BillingCacheService billingCacheService *BillingCacheService
deferredService *DeferredService deferredService *DeferredService
balanceNotifyService *BalanceNotifyService
} }
func (s *GatewayService) billingDeps() *billingDeps { func (s *GatewayService) billingDeps() *billingDeps {
return &billingDeps{ return &billingDeps{
accountRepo: s.accountRepo, accountRepo: s.accountRepo,
userRepo: s.userRepo, userRepo: s.userRepo,
userSubRepo: s.userSubRepo, userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService, billingCacheService: s.billingCacheService,
deferredService: s.deferredService, deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
} }
} }
@ -7746,6 +7827,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
ImageOutputTokens: result.Usage.ImageOutputTokens,
},
cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
@ -8086,6 +8184,19 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
return ch.BillingModelSource == BillingModelSourceUpstream return ch.BillingModelSource == BillingModelSourceUpstream
} }
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
if groupID == nil {
return false
}
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
return false
}
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API // ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应 // 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {

View File

@ -0,0 +1,394 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// Web search emulation constants
const (
toolTypeWebSearchPrefix = "web_search"
toolTypeGoogleSearch = "google_search"
toolNameWebSearch = "web_search"
toolNameGoogleSearch = "google_search"
toolNameWebSearch2025 = "web_search_20250305"
webSearchDefaultMaxResults = 5
defaultWebSearchModel = "claude-sonnet-4-6"
webSearchMsgIDPrefix = "msg_ws_"
webSearchToolUseIDPrefix = "srvtoolu_ws_"
tokenEstimateDivisor = 4
// featureKeyWebSearchEmulation is the key used in Account.Extra and Channel.FeaturesConfig.
featureKeyWebSearchEmulation = "web_search_emulation"
)
// webSearchManagerPtr stores *websearch.Manager atomically for concurrent safety.
var webSearchManagerPtr atomic.Pointer[websearch.Manager]
// SetWebSearchManager wires the websearch.Manager into the gateway (goroutine-safe).
func SetWebSearchManager(m *websearch.Manager) {
webSearchManagerPtr.Store(m)
}
func getWebSearchManager() *websearch.Manager {
return webSearchManagerPtr.Load()
}
// shouldEmulateWebSearch checks whether a request should be intercepted.
//
// Judgment chain: manager exists → only web_search tool → global enabled → account/channel enabled.
// Account-level mode: "enabled" (force on), "disabled" (force off), "default" (follow channel).
func (s *GatewayService) shouldEmulateWebSearch(ctx context.Context, account *Account, groupID *int64, body []byte) bool {
if getWebSearchManager() == nil {
return false
}
if !isOnlyWebSearchToolInBody(body) {
return false
}
if !s.settingService.IsWebSearchEmulationEnabled(ctx) {
return false
}
mode := account.GetWebSearchEmulationMode()
switch mode {
case WebSearchModeEnabled:
return true
case WebSearchModeDisabled:
return false
default: // "default" → follow channel config
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsWebSearchEmulationEnabled(account.Platform)
}
}
// isOnlyWebSearchToolInBody checks if the body contains exactly one web_search tool.
func isOnlyWebSearchToolInBody(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return false
}
arr := tools.Array()
if len(arr) != 1 {
return false
}
return isWebSearchToolJSON(arr[0])
}
func isWebSearchToolJSON(tool gjson.Result) bool {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, toolTypeWebSearchPrefix) || toolType == toolTypeGoogleSearch {
return true
}
switch tool.Get("name").String() {
case toolNameWebSearch, toolNameGoogleSearch, toolNameWebSearch2025:
return true
}
return false
}
// extractSearchQueryFromBody extracts the last user message text as the search query.
func extractSearchQueryFromBody(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return ""
}
arr := messages.Array()
if len(arr) == 0 {
return ""
}
lastMsg := arr[len(arr)-1]
if lastMsg.Get("role").String() != "user" {
return ""
}
return extractWebSearchTextFromContent(lastMsg.Get("content"))
}
func extractWebSearchTextFromContent(content gjson.Result) string {
if content.Type == gjson.String {
return content.String()
}
if content.IsArray() {
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
if text := block.Get("text").String(); text != "" {
return text
}
}
}
}
return ""
}
// handleWebSearchEmulation intercepts a web-search-only request,
// calls a third-party search API, and constructs an Anthropic-format response.
func (s *GatewayService) handleWebSearchEmulation(
ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest,
) (*ForwardResult, error) {
startTime := time.Now()
// Release the serial queue lock immediately — we don't need upstream.
if parsed.OnUpstreamAccepted != nil {
parsed.OnUpstreamAccepted()
}
query := extractSearchQueryFromBody(parsed.Body)
if query == "" {
return nil, fmt.Errorf("web search emulation: no query found in messages")
}
slog.Info("web search emulation: executing search",
"account_id", account.ID, "account_name", account.Name, "query", query)
resp, providerName, err := doWebSearch(ctx, account, query)
if err != nil {
// Proxy unavailable → trigger account switch via UpstreamFailoverError
if errors.Is(err, websearch.ErrProxyUnavailable) {
return nil, &UpstreamFailoverError{
StatusCode: http.StatusBadGateway,
ResponseBody: []byte(err.Error()),
}
}
return nil, err
}
slog.Info("web search emulation: search completed",
"provider", providerName, "results_count", len(resp.Results))
model := parsed.Model
if model == "" {
model = defaultWebSearchModel
}
if parsed.Stream {
return writeWebSearchStreamResponse(c, query, resp, model, startTime)
}
return writeWebSearchNonStreamResponse(c, query, resp, model, startTime)
}
func doWebSearch(ctx context.Context, account *Account, query string) (*websearch.SearchResponse, string, error) {
proxyURL := resolveAccountProxyURL(account)
mgr := getWebSearchManager()
if mgr == nil {
return nil, "", fmt.Errorf("web search emulation: manager not initialized")
}
resp, providerName, err := mgr.SearchWithBestProvider(ctx, websearch.SearchRequest{
Query: query, MaxResults: webSearchDefaultMaxResults, ProxyURL: proxyURL,
})
if err != nil {
slog.Error("web search emulation: search failed", "error", err)
return nil, "", fmt.Errorf("web search emulation: %w", err)
}
return resp, providerName, nil
}
func resolveAccountProxyURL(account *Account) string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}
// --- SSE streaming response ---
func writeWebSearchStreamResponse(
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
msgID := webSearchMsgIDPrefix + uuid.New().String()
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
textSummary := buildTextSummary(query, resp.Results)
setSSEHeaders(c)
w := c.Writer
for _, fn := range []func() error{
func() error { return writeSSEMessageStart(w, msgID, model) },
func() error { return writeSSEServerToolUse(w, toolUseID, query, 0) },
func() error { return writeSSEToolResult(w, toolUseID, resp.Results, 1) },
func() error { return writeSSETextBlock(w, textSummary, 2) },
func() error { return writeSSEMessageEnd(w, len(textSummary)/tokenEstimateDivisor) },
} {
if err := fn(); err != nil {
slog.Warn("web search emulation: SSE write failed, stopping", "error", err)
break
}
}
w.Flush()
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
func setSSEHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
}
func writeSSEMessageStart(w http.ResponseWriter, msgID, model string) error {
evt := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{}, "stop_reason": nil, "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": 0},
},
}
return flushSSEJSON(w, "message_start", evt)
}
func writeSSEServerToolUse(w http.ResponseWriter, toolUseID, query string, index int) error {
start := map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSEToolResult(w http.ResponseWriter, toolUseID string, results []websearch.SearchResult, index int) error {
start := map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{
"type": "web_search_tool_result", "tool_use_id": toolUseID,
"content": buildSearchResultBlocks(results),
},
}
if err := flushSSEJSON(w, "content_block_start", start); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSETextBlock(w http.ResponseWriter, text string, index int) error {
if err := flushSSEJSON(w, "content_block_start", map[string]any{
"type": "content_block_start", "index": index,
"content_block": map[string]any{"type": "text", "text": ""},
}); err != nil {
return err
}
if err := flushSSEJSON(w, "content_block_delta", map[string]any{
"type": "content_block_delta", "index": index,
"delta": map[string]string{"type": "text_delta", "text": text},
}); err != nil {
return err
}
return flushSSEJSON(w, "content_block_stop", map[string]any{"type": "content_block_stop", "index": index})
}
func writeSSEMessageEnd(w http.ResponseWriter, outputTokens int) error {
if err := flushSSEJSON(w, "message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{"stop_reason": "end_turn", "stop_sequence": nil},
"usage": map[string]int{"output_tokens": outputTokens},
}); err != nil {
return err
}
return flushSSEJSON(w, "message_stop", map[string]string{"type": "message_stop"})
}
// flushSSEJSON marshals data to JSON and writes an SSE event.
func flushSSEJSON(w http.ResponseWriter, event string, data any) error {
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, b); err != nil {
return fmt.Errorf("write: %w", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// --- Non-streaming JSON response ---
func writeWebSearchNonStreamResponse(
c *gin.Context, query string, resp *websearch.SearchResponse, model string, startTime time.Time,
) (*ForwardResult, error) {
msgID := webSearchMsgIDPrefix + uuid.New().String()
toolUseID := webSearchToolUseIDPrefix + uuid.New().String()[:16]
textSummary := buildTextSummary(query, resp.Results)
msg := map[string]any{
"id": msgID, "type": "message", "role": "assistant", "model": model,
"content": []any{
map[string]any{
"type": "server_tool_use", "id": toolUseID,
"name": toolNameWebSearch, "input": map[string]string{"query": query},
},
map[string]any{
"type": "web_search_tool_result", "tool_use_id": toolUseID,
"content": buildSearchResultBlocks(resp.Results),
},
map[string]any{"type": "text", "text": textSummary},
},
"stop_reason": "end_turn", "stop_sequence": nil,
"usage": map[string]int{"input_tokens": 0, "output_tokens": len(textSummary) / tokenEstimateDivisor},
}
body, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("web search emulation: marshal response: %w", err)
}
c.Data(http.StatusOK, "application/json", body)
return &ForwardResult{Model: model, Duration: time.Since(startTime), Usage: ClaudeUsage{}}, nil
}
// --- Helpers ---
func buildSearchResultBlocks(results []websearch.SearchResult) []map[string]string {
blocks := make([]map[string]string, 0, len(results))
for _, r := range results {
block := map[string]string{
"type": "web_search_result",
"url": r.URL,
"title": r.Title,
}
if r.Snippet != "" {
block["page_content"] = r.Snippet
}
if r.PageAge != "" {
block["page_age"] = r.PageAge
}
blocks = append(blocks, block)
}
return blocks
}
func buildTextSummary(query string, results []websearch.SearchResult) string {
if len(results) == 0 {
return "No search results found for: " + query
}
var sb strings.Builder
fmt.Fprintf(&sb, "Here are the search results for \"%s\":\n\n", query)
for i, r := range results {
fmt.Fprintf(&sb, "%d. **%s**\n %s\n %s\n\n", i+1, r.Title, r.URL, r.Snippet)
}
return sb.String()
}

View File

@ -0,0 +1,380 @@
//go:build unit
package service
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
"github.com/stretchr/testify/require"
)
// --- isOnlyWebSearchToolInBody ---
func TestIsOnlyWebSearchToolInBody_WebSearchType(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_WebSearch2025Type(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"web_search_20250305"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_GoogleSearchType(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"google_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameWebSearch(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameWebSearch2025(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"web_search_20250305"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NameGoogleSearch(t *testing.T) {
require.True(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"name":"google_search"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_MultipleTools(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody(
[]byte(`{"tools":[{"type":"web_search"},{"type":"text_editor"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_NoTools(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"model":"claude-3"}`)))
}
func TestIsOnlyWebSearchToolInBody_EmptyToolsArray(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[]}`)))
}
func TestIsOnlyWebSearchToolInBody_NonWebSearchTool(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":[{"type":"text_editor"}]}`)))
}
func TestIsOnlyWebSearchToolInBody_ToolsNotArray(t *testing.T) {
require.False(t, isOnlyWebSearchToolInBody([]byte(`{"tools":"web_search"}`)))
}
// --- extractSearchQueryFromBody ---
func TestExtractSearchQueryFromBody_StringContent(t *testing.T) {
body := `{"messages":[{"role":"user","content":"what is golang"}]}`
require.Equal(t, "what is golang", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_ArrayContent(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"text","text":"search this"}]}]}`
require.Equal(t, "search this", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_MultipleMessages(t *testing.T) {
body := `{"messages":[{"role":"user","content":"first"},{"role":"assistant","content":"ok"},{"role":"user","content":"second"}]}`
require.Equal(t, "second", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_LastMessageNotUser(t *testing.T) {
body := `{"messages":[{"role":"user","content":"q"},{"role":"assistant","content":"a"}]}`
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_EmptyMessages(t *testing.T) {
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"messages":[]}`)))
}
func TestExtractSearchQueryFromBody_NoMessages(t *testing.T) {
require.Equal(t, "", extractSearchQueryFromBody([]byte(`{"model":"claude-3"}`)))
}
func TestExtractSearchQueryFromBody_ArrayContentSkipsEmptyText(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"image"},{"type":"text","text":""},{"type":"text","text":"real query"}]}]}`
require.Equal(t, "real query", extractSearchQueryFromBody([]byte(body)))
}
func TestExtractSearchQueryFromBody_ArrayContentNoTextBlock(t *testing.T) {
body := `{"messages":[{"role":"user","content":[{"type":"image","source":{}}]}]}`
require.Equal(t, "", extractSearchQueryFromBody([]byte(body)))
}
// --- buildSearchResultBlocks ---
func TestBuildSearchResultBlocks_WithResults(t *testing.T) {
results := []websearch.SearchResult{
{URL: "https://a.com", Title: "A", Snippet: "snippet a", PageAge: "2 days"},
{URL: "https://b.com", Title: "B", Snippet: "snippet b"},
}
blocks := buildSearchResultBlocks(results)
require.Len(t, blocks, 2)
require.Equal(t, "web_search_result", blocks[0]["type"])
require.Equal(t, "https://a.com", blocks[0]["url"])
require.Equal(t, "snippet a", blocks[0]["page_content"])
require.Equal(t, "2 days", blocks[0]["page_age"])
// Second result has no PageAge
require.Equal(t, "https://b.com", blocks[1]["url"])
_, hasPageAge := blocks[1]["page_age"]
require.False(t, hasPageAge)
}
func TestBuildSearchResultBlocks_Empty(t *testing.T) {
blocks := buildSearchResultBlocks(nil)
require.Empty(t, blocks)
}
func TestBuildSearchResultBlocks_SnippetEmpty(t *testing.T) {
blocks := buildSearchResultBlocks([]websearch.SearchResult{{URL: "https://x.com", Title: "X", Snippet: ""}})
_, hasContent := blocks[0]["page_content"]
require.False(t, hasContent)
}
// --- buildTextSummary ---
func TestBuildTextSummary_WithResults(t *testing.T) {
results := []websearch.SearchResult{
{URL: "https://a.com", Title: "A", Snippet: "desc a"},
}
summary := buildTextSummary("test query", results)
require.Contains(t, summary, "test query")
require.Contains(t, summary, "1. **A**")
require.Contains(t, summary, "https://a.com")
}
func TestBuildTextSummary_NoResults(t *testing.T) {
summary := buildTextSummary("test", nil)
require.Contains(t, summary, "No search results found for: test")
}
// --- shouldEmulateWebSearch ---
// webSearchToolBody is a valid request body with exactly one web_search tool.
var webSearchToolBody = []byte(`{"tools":[{"type":"web_search"}],"messages":[{"role":"user","content":"test"}]}`)
// nonWebSearchToolBody is a request body without web_search tool.
var nonWebSearchToolBody = []byte(`{"tools":[{"type":"text_editor"}],"messages":[{"role":"user","content":"test"}]}`)
// newAnthropicAPIKeyAccount creates a test Account with the given web search emulation mode.
func newAnthropicAPIKeyAccount(mode string) *Account {
return &Account{
ID: 1,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{featureKeyWebSearchEmulation: mode},
}
}
// setGlobalWebSearchConfig stores a config in the global cache used by SettingService.IsWebSearchEmulationEnabled.
func setGlobalWebSearchConfig(cfg *WebSearchEmulationConfig) {
webSearchEmulationCache.Store(&cachedWebSearchEmulationConfig{
config: cfg,
expiresAt: time.Now().Add(10 * time.Minute).UnixNano(),
})
}
// clearGlobalWebSearchConfig resets the global cache to force re-read.
func clearGlobalWebSearchConfig() {
webSearchEmulationCache.Store((*cachedWebSearchEmulationConfig)(nil))
}
// newSettingServiceForWebSearchTest creates a SettingService with a mock repo pre-loaded with config.
func newSettingServiceForWebSearchTest(enabled bool) *SettingService {
repo := newMockSettingRepo()
cfg := &WebSearchEmulationConfig{
Enabled: enabled,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "sk-test"}},
}
data, _ := json.Marshal(cfg)
repo.data[SettingKeyWebSearchEmulationConfig] = string(data)
return NewSettingService(repo, &config.Config{})
}
// newChannelServiceWithCache creates a ChannelService with a pre-built cache containing the channel.
func newChannelServiceWithCache(groupID int64, ch *Channel) *ChannelService {
svc := &ChannelService{}
cache := &channelCache{
channelByGroupID: map[int64]*Channel{groupID: ch},
byID: map[int64]*Channel{ch.ID: ch},
groupPlatform: map[int64]string{},
loadedAt: time.Now(),
}
svc.cache.Store(cache)
return svc
}
func TestShouldEmulateWebSearch_NilManager(t *testing.T) {
SetWebSearchManager(nil)
defer SetWebSearchManager(nil)
settingSvc := newSettingServiceForWebSearchTest(true)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_NotOnlyWebSearchTool(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
settingSvc := newSettingServiceForWebSearchTest(true)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, nonWebSearchToolBody))
}
func TestShouldEmulateWebSearch_GlobalDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
// Global config disabled
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: false,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(false)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_AccountDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDisabled)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_AccountEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeEnabled)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_ChannelEnabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 10,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: true},
},
}
channelSvc := newChannelServiceWithCache(42, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
require.True(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_ChannelDisabled(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
ch := &Channel{
ID: 10,
Status: StatusActive,
FeaturesConfig: map[string]any{
featureKeyWebSearchEmulation: map[string]any{PlatformAnthropic: false},
},
}
channelSvc := newChannelServiceWithCache(42, ch)
svc := &GatewayService{settingService: settingSvc, channelService: channelSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_NilGroupID(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
// nil groupID + default mode → falls through to channel check → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, nil, webSearchToolBody))
}
func TestShouldEmulateWebSearch_DefaultMode_NilChannelService(t *testing.T) {
mgr := websearch.NewManager([]websearch.ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
SetWebSearchManager(mgr)
defer SetWebSearchManager(nil)
setGlobalWebSearchConfig(&WebSearchEmulationConfig{
Enabled: true,
Providers: []WebSearchProviderConfig{{Type: "brave", APIKey: "k"}},
})
defer clearGlobalWebSearchConfig()
settingSvc := newSettingServiceForWebSearchTest(true)
svc := &GatewayService{settingService: settingSvc, channelService: nil}
account := newAnthropicAPIKeyAccount(WebSearchModeDefault)
groupID := int64(42)
// nil channelService + default mode → returns false
require.False(t, svc.shouldEmulateWebSearch(context.Background(), account, &groupID, webSearchToolBody))
}

View File

@ -0,0 +1,81 @@
package service
import (
"encoding/json"
"strings"
)
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
// All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct {
Email string `json:"email"`
Disabled bool `json:"disabled"`
Verified bool `json:"verified"`
}
// parseNotifyEmails parses a JSON string into []NotifyEmailEntry.
// It auto-detects the format:
// - Old format ["email1","email2"] → converted to [{email, disabled:false, verified:true}, ...]
// - New format [{email,disabled,verified}, ...] → parsed directly
//
// Returns nil on empty/invalid input.
func ParseNotifyEmails(raw string) []NotifyEmailEntry {
raw = strings.TrimSpace(raw)
if raw == "" || raw == "[]" {
return nil
}
// Try parsing as new format first (array of objects)
var entries []NotifyEmailEntry
if err := json.Unmarshal([]byte(raw), &entries); err == nil && len(entries) > 0 {
// Verify it's actually the new format by checking the first element
// json.Unmarshal into []NotifyEmailEntry succeeds even for ["string"]
// because it tries to fit "string" into NotifyEmailEntry and gets zero values.
// We need to detect old format explicitly.
if !isOldStringArrayFormat(raw) {
return entries
}
}
// Try parsing as old format (array of strings)
var emails []string
if err := json.Unmarshal([]byte(raw), &emails); err == nil {
result := make([]NotifyEmailEntry, 0, len(emails))
for _, e := range emails {
e = strings.TrimSpace(e)
if e != "" {
result = append(result, NotifyEmailEntry{
Email: e,
Disabled: false,
Verified: false, // Old format emails default to unverified
})
}
}
return result
}
return nil
}
// isOldStringArrayFormat checks if the JSON is a string array like ["email1","email2"].
func isOldStringArrayFormat(raw string) bool {
var arr []json.RawMessage
if err := json.Unmarshal([]byte(raw), &arr); err != nil || len(arr) == 0 {
return false
}
// Check if first element starts with a quote (string) vs { (object)
first := strings.TrimSpace(string(arr[0]))
return len(first) > 0 && first[0] == '"'
}
// marshalNotifyEmails serializes []NotifyEmailEntry to JSON string.
func MarshalNotifyEmails(entries []NotifyEmailEntry) string {
if len(entries) == 0 {
return "[]"
}
data, err := json.Marshal(entries)
if err != nil {
return "[]"
}
return string(data)
}

View File

@ -0,0 +1,156 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// ---------- ParseNotifyEmails ----------
func TestParseNotifyEmails_EmptyString(t *testing.T) {
result := ParseNotifyEmails("")
require.Nil(t, result)
}
func TestParseNotifyEmails_EmptyArray(t *testing.T) {
result := ParseNotifyEmails("[]")
require.Nil(t, result)
}
func TestParseNotifyEmails_Null(t *testing.T) {
// "null" is valid JSON that unmarshals into a nil string slice.
// The old-format branch then returns an empty (non-nil) slice.
result := ParseNotifyEmails("null")
require.Empty(t, result)
}
func TestParseNotifyEmails_WhitespaceOnly(t *testing.T) {
result := ParseNotifyEmails(" ")
require.Nil(t, result)
}
func TestParseNotifyEmails_OldFormat(t *testing.T) {
raw := `["alice@example.com", "bob@example.com"]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.False(t, result[0].Verified, "old format emails should default to unverified")
require.False(t, result[0].Disabled)
require.Equal(t, "bob@example.com", result[1].Email)
require.False(t, result[1].Verified)
require.False(t, result[1].Disabled)
}
func TestParseNotifyEmails_OldFormat_SkipsEmptyEntries(t *testing.T) {
raw := `["alice@example.com", "", " ", "bob@example.com"]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.Equal(t, "bob@example.com", result[1].Email)
}
func TestParseNotifyEmails_NewFormat(t *testing.T) {
raw := `[{"email":"alice@example.com","verified":true,"disabled":false},{"email":"bob@example.com","verified":false,"disabled":true}]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 2)
require.Equal(t, "alice@example.com", result[0].Email)
require.True(t, result[0].Verified)
require.False(t, result[0].Disabled)
require.Equal(t, "bob@example.com", result[1].Email)
require.False(t, result[1].Verified)
require.True(t, result[1].Disabled)
}
func TestParseNotifyEmails_NewFormat_SingleEntry(t *testing.T) {
raw := `[{"email":"solo@example.com","verified":true,"disabled":false}]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "solo@example.com", result[0].Email)
require.True(t, result[0].Verified)
}
func TestParseNotifyEmails_InvalidJSON(t *testing.T) {
result := ParseNotifyEmails(`{not valid json`)
require.Nil(t, result)
}
func TestParseNotifyEmails_InvalidJSONObject(t *testing.T) {
// A plain JSON object (not array) should return nil.
result := ParseNotifyEmails(`{"email":"a@b.com"}`)
require.Nil(t, result)
}
func TestParseNotifyEmails_WhitespacePadding(t *testing.T) {
raw := ` ["padded@example.com"] `
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "padded@example.com", result[0].Email)
}
// ---------- MarshalNotifyEmails ----------
func TestMarshalNotifyEmails_EmptySlice(t *testing.T) {
result := MarshalNotifyEmails([]NotifyEmailEntry{})
require.Equal(t, "[]", result)
}
func TestMarshalNotifyEmails_NilSlice(t *testing.T) {
result := MarshalNotifyEmails(nil)
require.Equal(t, "[]", result)
}
func TestMarshalNotifyEmails_SingleEntry(t *testing.T) {
entries := []NotifyEmailEntry{
{Email: "test@example.com", Verified: true, Disabled: false},
}
result := MarshalNotifyEmails(entries)
require.Contains(t, result, `"email":"test@example.com"`)
require.Contains(t, result, `"verified":true`)
require.Contains(t, result, `"disabled":false`)
// Round-trip: parsing the marshalled result should produce the original entries.
parsed := ParseNotifyEmails(result)
require.Len(t, parsed, 1)
require.Equal(t, entries[0], parsed[0])
}
func TestMarshalNotifyEmails_MultipleEntries(t *testing.T) {
entries := []NotifyEmailEntry{
{Email: "a@example.com", Verified: true, Disabled: false},
{Email: "b@example.com", Verified: false, Disabled: true},
}
result := MarshalNotifyEmails(entries)
// Round-trip verification.
parsed := ParseNotifyEmails(result)
require.Len(t, parsed, 2)
require.Equal(t, entries[0], parsed[0])
require.Equal(t, entries[1], parsed[1])
}
func TestMarshalNotifyEmails_RoundTrip_NewFormat(t *testing.T) {
original := []NotifyEmailEntry{
{Email: "x@example.com", Verified: true, Disabled: true},
{Email: "y@example.com", Verified: false, Disabled: false},
}
marshalled := MarshalNotifyEmails(original)
parsed := ParseNotifyEmails(marshalled)
require.Equal(t, original, parsed)
}
// ---------- isOldStringArrayFormat (indirectly via ParseNotifyEmails) ----------
func TestParseNotifyEmails_MixedOldFormatWithWhitespace(t *testing.T) {
// Emails with leading/trailing whitespace in old format should be trimmed.
raw := `[" alice@example.com "]`
result := ParseNotifyEmails(raw)
require.Len(t, result, 1)
require.Equal(t, "alice@example.com", result[0].Email)
}

View File

@ -147,6 +147,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U
nil, nil,
nil, nil,
nil, nil,
nil,
) )
svc.userGroupRateResolver = newUserGroupRateResolver( svc.userGroupRateResolver = newUserGroupRateResolver(
rateRepo, rateRepo,

View File

@ -327,6 +327,7 @@ type OpenAIGatewayService struct {
openaiWSResolver OpenAIWSProtocolResolver openaiWSResolver OpenAIWSProtocolResolver
resolver *ModelPricingResolver resolver *ModelPricingResolver
channelService *ChannelService channelService *ChannelService
balanceNotifyService *BalanceNotifyService
openaiWSPoolOnce sync.Once openaiWSPoolOnce sync.Once
openaiWSStateStoreOnce sync.Once openaiWSStateStoreOnce sync.Once
@ -364,6 +365,7 @@ func NewOpenAIGatewayService(
openAITokenProvider *OpenAITokenProvider, openAITokenProvider *OpenAITokenProvider,
resolver *ModelPricingResolver, resolver *ModelPricingResolver,
channelService *ChannelService, channelService *ChannelService,
balanceNotifyService *BalanceNotifyService,
) *OpenAIGatewayService { ) *OpenAIGatewayService {
svc := &OpenAIGatewayService{ svc := &OpenAIGatewayService{
accountRepo: accountRepo, accountRepo: accountRepo,
@ -393,6 +395,7 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
resolver: resolver, resolver: resolver,
channelService: channelService, channelService: channelService,
balanceNotifyService: balanceNotifyService,
responseHeaderFilter: compileResponseHeaderFilter(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
} }
@ -477,11 +480,12 @@ func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle
func (s *OpenAIGatewayService) billingDeps() *billingDeps { func (s *OpenAIGatewayService) billingDeps() *billingDeps {
return &billingDeps{ return &billingDeps{
accountRepo: s.accountRepo, accountRepo: s.accountRepo,
userRepo: s.userRepo, userRepo: s.userRepo,
userSubRepo: s.userSubRepo, userSubRepo: s.userSubRepo,
billingCacheService: s.billingCacheService, billingCacheService: s.billingCacheService,
deferredService: s.deferredService, deferredService: s.deferredService,
balanceNotifyService: s.balanceNotifyService,
} }
} }
@ -4569,6 +4573,14 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID usageLog.SubscriptionID = &subscription.ID
} }
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
if apiKey.GroupID != nil {
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
tokens, cost.TotalCost,
)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway")
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())

View File

@ -413,7 +413,12 @@ func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughModeR
select { select {
case serverErr := <-serverErrCh: case serverErr := <-serverErrCh:
require.NoError(t, serverErr) // After normal client close, the server goroutine may receive the close frame
// as an error — this is expected behavior, not a test failure.
if serverErr != nil {
require.Contains(t, serverErr.Error(), "StatusNormalClosure",
"server error should only be a normal close frame, got: %v", serverErr)
}
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatal("等待 passthrough websocket 结束超时") t.Fatal("等待 passthrough websocket 结束超时")
} }

Some files were not shown because too many files have changed in this diff Show More