diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 6f76ef4f..d7e15377 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -17,6 +17,7 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.26.2' @@ -36,6 +37,7 @@ jobs: go-version-file: backend/go.mod check-latest: false cache: true + cache-dependency-path: backend/go.sum - name: Verify Go version run: | go version | grep -q 'go1.26.2' diff --git a/README.md b/README.md index 3a56d089..74ab9af2 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off! + +AIGoCode +Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link, you'll receive an extra 10% bonus credit on your first top-up! + + + +bmoplus +Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF) + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index c0e6492e..c701372c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -85,6 +85,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过此链接注册,首次充值立享 8 折优惠,企业客户最高可享 75 折! + +AIGoCode +感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接注册,首次充值可额外获得 10% 赠送额度! + + + +bmoplus +感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index 4605b877..0d4db616 100644 --- a/README_JA.md +++ b/README_JA.md @@ -85,6 +85,16 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:こちらのリンクから登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ! + +AIGoCode +AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンクから登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント! + + + +bmoplus +本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます! + + ## エコシステム diff --git a/assets/partners/logos/aigocode.png b/assets/partners/logos/aigocode.png new file mode 100644 index 00000000..6dd5965a Binary files /dev/null and b/assets/partners/logos/aigocode.png differ diff --git a/assets/partners/logos/bmoplus.jpg b/assets/partners/logos/bmoplus.jpg new file mode 100644 index 00000000..1a9b4d8b Binary files /dev/null and b/assets/partners/logos/bmoplus.jpg differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 715965f3..c21e67e6 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.111 +0.1.113 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 8cb15572..64709b5b 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -36,15 +36,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { // Business layer ProviderSets repository.ProviderSet, service.ProviderSet, + payment.ProviderSet, middleware.ProviderSet, handler.ProviderSet, // Server layer ProviderSet server.ProviderSet, - // Payment providers - payment.ProviderSet, - // Privacy client factory for OpenAI training opt-out providePrivacyClientFactory, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index c4b441f1..1d39fa1e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -50,7 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) groupRepository := repository.NewGroupRepository(client, db) - settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) + proxyRepository := repository.NewProxyRepository(client, db) + settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -68,7 +69,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) + userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -78,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) - userHandler := handler.NewUserHandler(userService) + userHandler := handler.NewUserHandler(userService, emailService, emailCache) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) @@ -100,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) - proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() @@ -136,7 +136,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) + oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) @@ -176,21 +176,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelRepository := repository.NewChannelRepository(db) channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator) modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - encryptionKey, err := payment.ProvideEncryptionKey(configConfig) - if err != nil { - return nil, err - } - paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - registry := payment.ProvideRegistry() - defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -218,6 +210,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) + registry := payment.ProvideRegistry() + encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + if err != nil { + return nil, err + } + defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) + paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) + paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) @@ -235,8 +237,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) - languageServerService := service.ProvideLanguageServerService(httpUpstream) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, languageServerService) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient) httpServer := server.ProvideHTTPServer(configConfig, engine) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) @@ -247,7 +248,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) - paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) application := &Application{ Server: httpServer, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index e947b2e8..68bdbf55 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -616,6 +616,7 @@ var ( {Name: "sort_order", Type: field.TypeInt, Default: 0}, {Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "refund_enabled", Type: field.TypeBool, Default: false}, + {Name: "allow_user_refund", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, } @@ -1078,6 +1079,11 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, + {Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"}, + {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}}, + {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 6b2fa838..524ccb92 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error { // PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph. type PaymentProviderInstanceMutation struct { config - op Op - typ string - id *int64 - provider_key *string - name *string - _config *string - supported_types *string - enabled *bool - payment_mode *string - sort_order *int - addsort_order *int - limits *string - refund_enabled *bool - created_at *time.Time - updated_at *time.Time - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*PaymentProviderInstance, error) - predicates []predicate.PaymentProviderInstance + op Op + typ string + id *int64 + provider_key *string + name *string + _config *string + supported_types *string + enabled *bool + payment_mode *string + sort_order *int + addsort_order *int + limits *string + refund_enabled *bool + allow_user_refund *bool + created_at *time.Time + updated_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*PaymentProviderInstance, error) + predicates []predicate.PaymentProviderInstance } var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil) @@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() { m.refund_enabled = nil } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) { + m.allow_user_refund = &b +} + +// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation. +func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) { + v := m.allow_user_refund + if v == nil { + return + } + return *v, true +} + +// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity. +// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAllowUserRefund requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err) + } + return oldValue.AllowUserRefund, nil +} + +// ResetAllowUserRefund resets all changes to the "allow_user_refund" field. +func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() { + m.allow_user_refund = nil +} + // SetCreatedAt sets the "created_at" field. func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PaymentProviderInstanceMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.provider_key != nil { fields = append(fields, paymentproviderinstance.FieldProviderKey) } @@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string { if m.refund_enabled != nil { fields = append(fields, paymentproviderinstance.FieldRefundEnabled) } + if m.allow_user_refund != nil { + fields = append(fields, paymentproviderinstance.FieldAllowUserRefund) + } if m.created_at != nil { fields = append(fields, paymentproviderinstance.FieldCreatedAt) } @@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) { return m.Limits() case paymentproviderinstance.FieldRefundEnabled: return m.RefundEnabled() + case paymentproviderinstance.FieldAllowUserRefund: + return m.AllowUserRefund() case paymentproviderinstance.FieldCreatedAt: return m.CreatedAt() case paymentproviderinstance.FieldUpdatedAt: @@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str return m.OldLimits(ctx) case paymentproviderinstance.FieldRefundEnabled: return m.OldRefundEnabled(ctx) + case paymentproviderinstance.FieldAllowUserRefund: + return m.OldAllowUserRefund(ctx) case paymentproviderinstance.FieldCreatedAt: return m.OldCreatedAt(ctx) case paymentproviderinstance.FieldUpdatedAt: @@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value) } m.SetRefundEnabled(v) return nil + case paymentproviderinstance.FieldAllowUserRefund: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAllowUserRefund(v) + return nil case paymentproviderinstance.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error { case paymentproviderinstance.FieldRefundEnabled: m.ResetRefundEnabled() return nil + case paymentproviderinstance.FieldAllowUserRefund: + m.ResetAllowUserRefund() + return nil case paymentproviderinstance.FieldCreatedAt: m.ResetCreatedAt() return nil @@ -28210,6 +28264,13 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + balance_notify_enabled *bool + balance_notify_threshold_type *string + balance_notify_threshold *float64 + addbalance_notify_threshold *float64 + balance_notify_extra_emails *string + total_recharged *float64 + addtotal_recharged *float64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -28927,6 +28988,240 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (m *UserMutation) SetBalanceNotifyEnabled(b bool) { + m.balance_notify_enabled = &b +} + +// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation. +func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) { + v := m.balance_notify_enabled + if v == nil { + return + } + return *v, true +} + +// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err) + } + return oldValue.BalanceNotifyEnabled, nil +} + +// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field. +func (m *UserMutation) ResetBalanceNotifyEnabled() { + m.balance_notify_enabled = nil +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (m *UserMutation) SetBalanceNotifyThresholdType(s string) { + m.balance_notify_threshold_type = &s +} + +// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation. +func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) { + v := m.balance_notify_threshold_type + if v == nil { + return + } + return *v, true +} + +// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err) + } + return oldValue.BalanceNotifyThresholdType, nil +} + +// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field. +func (m *UserMutation) ResetBalanceNotifyThresholdType() { + m.balance_notify_threshold_type = nil +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (m *UserMutation) SetBalanceNotifyThreshold(f float64) { + m.balance_notify_threshold = &f + m.addbalance_notify_threshold = nil +} + +// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation. +func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) { + v := m.balance_notify_threshold + if v == nil { + return + } + return *v, true +} + +// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err) + } + return oldValue.BalanceNotifyThreshold, nil +} + +// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field. +func (m *UserMutation) AddBalanceNotifyThreshold(f float64) { + if m.addbalance_notify_threshold != nil { + *m.addbalance_notify_threshold += f + } else { + m.addbalance_notify_threshold = &f + } +} + +// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation. +func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) { + v := m.addbalance_notify_threshold + if v == nil { + return + } + return *v, true +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (m *UserMutation) ClearBalanceNotifyThreshold() { + m.balance_notify_threshold = nil + m.addbalance_notify_threshold = nil + m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{} +} + +// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation. +func (m *UserMutation) BalanceNotifyThresholdCleared() bool { + _, ok := m.clearedFields[user.FieldBalanceNotifyThreshold] + return ok +} + +// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field. +func (m *UserMutation) ResetBalanceNotifyThreshold() { + m.balance_notify_threshold = nil + m.addbalance_notify_threshold = nil + delete(m.clearedFields, user.FieldBalanceNotifyThreshold) +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) { + m.balance_notify_extra_emails = &s +} + +// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation. +func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) { + v := m.balance_notify_extra_emails + if v == nil { + return + } + return *v, true +} + +// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err) + } + return oldValue.BalanceNotifyExtraEmails, nil +} + +// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field. +func (m *UserMutation) ResetBalanceNotifyExtraEmails() { + m.balance_notify_extra_emails = nil +} + +// SetTotalRecharged sets the "total_recharged" field. +func (m *UserMutation) SetTotalRecharged(f float64) { + m.total_recharged = &f + m.addtotal_recharged = nil +} + +// TotalRecharged returns the value of the "total_recharged" field in the mutation. +func (m *UserMutation) TotalRecharged() (r float64, exists bool) { + v := m.total_recharged + if v == nil { + return + } + return *v, true +} + +// OldTotalRecharged returns the old "total_recharged" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalRecharged requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err) + } + return oldValue.TotalRecharged, nil +} + +// AddTotalRecharged adds f to the "total_recharged" field. +func (m *UserMutation) AddTotalRecharged(f float64) { + if m.addtotal_recharged != nil { + *m.addtotal_recharged += f + } else { + m.addtotal_recharged = &f + } +} + +// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation. +func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) { + v := m.addtotal_recharged + if v == nil { + return + } + return *v, true +} + +// ResetTotalRecharged resets all changes to the "total_recharged" field. +func (m *UserMutation) ResetTotalRecharged() { + m.total_recharged = nil + m.addtotal_recharged = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -29501,7 +29796,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 19) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -29544,6 +29839,21 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.balance_notify_enabled != nil { + fields = append(fields, user.FieldBalanceNotifyEnabled) + } + if m.balance_notify_threshold_type != nil { + fields = append(fields, user.FieldBalanceNotifyThresholdType) + } + if m.balance_notify_threshold != nil { + fields = append(fields, user.FieldBalanceNotifyThreshold) + } + if m.balance_notify_extra_emails != nil { + fields = append(fields, user.FieldBalanceNotifyExtraEmails) + } + if m.total_recharged != nil { + fields = append(fields, user.FieldTotalRecharged) + } return fields } @@ -29580,6 +29890,16 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldBalanceNotifyEnabled: + return m.BalanceNotifyEnabled() + case user.FieldBalanceNotifyThresholdType: + return m.BalanceNotifyThresholdType() + case user.FieldBalanceNotifyThreshold: + return m.BalanceNotifyThreshold() + case user.FieldBalanceNotifyExtraEmails: + return m.BalanceNotifyExtraEmails() + case user.FieldTotalRecharged: + return m.TotalRecharged() } return nil, false } @@ -29617,6 +29937,16 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldBalanceNotifyEnabled: + return m.OldBalanceNotifyEnabled(ctx) + case user.FieldBalanceNotifyThresholdType: + return m.OldBalanceNotifyThresholdType(ctx) + case user.FieldBalanceNotifyThreshold: + return m.OldBalanceNotifyThreshold(ctx) + case user.FieldBalanceNotifyExtraEmails: + return m.OldBalanceNotifyExtraEmails(ctx) + case user.FieldTotalRecharged: + return m.OldTotalRecharged(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -29724,6 +30054,41 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldBalanceNotifyEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBalanceNotifyEnabled(v) + return nil + case user.FieldBalanceNotifyThresholdType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBalanceNotifyThresholdType(v) + return nil + case user.FieldBalanceNotifyThreshold: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBalanceNotifyThreshold(v) + return nil + case user.FieldBalanceNotifyExtraEmails: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBalanceNotifyExtraEmails(v) + return nil + case user.FieldTotalRecharged: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalRecharged(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -29738,6 +30103,12 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } + if m.addbalance_notify_threshold != nil { + fields = append(fields, user.FieldBalanceNotifyThreshold) + } + if m.addtotal_recharged != nil { + fields = append(fields, user.FieldTotalRecharged) + } return fields } @@ -29750,6 +30121,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() + case user.FieldBalanceNotifyThreshold: + return m.AddedBalanceNotifyThreshold() + case user.FieldTotalRecharged: + return m.AddedTotalRecharged() } return nil, false } @@ -29773,6 +30148,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case user.FieldBalanceNotifyThreshold: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddBalanceNotifyThreshold(v) + return nil + case user.FieldTotalRecharged: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalRecharged(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -29790,6 +30179,9 @@ func (m *UserMutation) ClearedFields() []string { if m.FieldCleared(user.FieldTotpEnabledAt) { fields = append(fields, user.FieldTotpEnabledAt) } + if m.FieldCleared(user.FieldBalanceNotifyThreshold) { + fields = append(fields, user.FieldBalanceNotifyThreshold) + } return fields } @@ -29813,6 +30205,9 @@ func (m *UserMutation) ClearField(name string) error { case user.FieldTotpEnabledAt: m.ClearTotpEnabledAt() return nil + case user.FieldBalanceNotifyThreshold: + m.ClearBalanceNotifyThreshold() + return nil } return fmt.Errorf("unknown User nullable field %s", name) } @@ -29863,6 +30258,21 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldBalanceNotifyEnabled: + m.ResetBalanceNotifyEnabled() + return nil + case user.FieldBalanceNotifyThresholdType: + m.ResetBalanceNotifyThresholdType() + return nil + case user.FieldBalanceNotifyThreshold: + m.ResetBalanceNotifyThreshold() + return nil + case user.FieldBalanceNotifyExtraEmails: + m.ResetBalanceNotifyExtraEmails() + return nil + case user.FieldTotalRecharged: + m.ResetTotalRecharged() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/paymentproviderinstance.go b/backend/ent/paymentproviderinstance.go index 087cb13a..4279b86e 100644 --- a/backend/ent/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance.go @@ -35,6 +35,8 @@ type PaymentProviderInstance struct { Limits string `json:"limits,omitempty"` // RefundEnabled holds the value of the "refund_enabled" field. RefundEnabled bool `json:"refund_enabled,omitempty"` + // AllowUserRefund holds the value of the "allow_user_refund" field. + AllowUserRefund bool `json:"allow_user_refund,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // UpdatedAt holds the value of the "updated_at" field. @@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled: + case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund: values[i] = new(sql.NullBool) case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder: values[i] = new(sql.NullInt64) @@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any) } else if value.Valid { _m.RefundEnabled = value.Bool } + case paymentproviderinstance.FieldAllowUserRefund: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i]) + } else if value.Valid { + _m.AllowUserRefund = value.Bool + } case paymentproviderinstance.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string { builder.WriteString("refund_enabled=") builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled)) builder.WriteString(", ") + builder.WriteString("allow_user_refund=") + builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund)) + builder.WriteString(", ") builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") diff --git a/backend/ent/paymentproviderinstance/paymentproviderinstance.go b/backend/ent/paymentproviderinstance/paymentproviderinstance.go index c430fef6..eb1b0c52 100644 --- a/backend/ent/paymentproviderinstance/paymentproviderinstance.go +++ b/backend/ent/paymentproviderinstance/paymentproviderinstance.go @@ -31,6 +31,8 @@ const ( FieldLimits = "limits" // FieldRefundEnabled holds the string denoting the refund_enabled field in the database. FieldRefundEnabled = "refund_enabled" + // FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database. + FieldAllowUserRefund = "allow_user_refund" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldUpdatedAt holds the string denoting the updated_at field in the database. @@ -51,6 +53,7 @@ var Columns = []string{ FieldSortOrder, FieldLimits, FieldRefundEnabled, + FieldAllowUserRefund, FieldCreatedAt, FieldUpdatedAt, } @@ -88,6 +91,8 @@ var ( DefaultLimits string // DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field. DefaultRefundEnabled bool + // DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field. + DefaultAllowUserRefund bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. @@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc() } +// ByAllowUserRefund orders the results by the allow_user_refund field. +func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc() +} + // ByCreatedAt orders the results by the created_at field. func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() diff --git a/backend/ent/paymentproviderinstance/where.go b/backend/ent/paymentproviderinstance/where.go index 7b99517f..40e5a1f6 100644 --- a/backend/ent/paymentproviderinstance/where.go +++ b/backend/ent/paymentproviderinstance/where.go @@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v)) } +// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ. +func AllowUserRefund(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) @@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v)) } +// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field. +func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v)) +} + +// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field. +func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance { + return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance { return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v)) diff --git a/backend/ent/paymentproviderinstance_create.go b/backend/ent/paymentproviderinstance_create.go index 20b16ddd..d1b14617 100644 --- a/backend/ent/paymentproviderinstance_create.go +++ b/backend/ent/paymentproviderinstance_create.go @@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym return _c } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate { + _c.mutation.SetAllowUserRefund(v) + return _c +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate { + if v != nil { + _c.SetAllowUserRefund(*v) + } + return _c +} + // SetCreatedAt sets the "created_at" field. func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate { _c.mutation.SetCreatedAt(v) @@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() { v := paymentproviderinstance.DefaultRefundEnabled _c.mutation.SetRefundEnabled(v) } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + v := paymentproviderinstance.DefaultAllowUserRefund + _c.mutation.SetAllowUserRefund(v) + } if _, ok := _c.mutation.CreatedAt(); !ok { v := paymentproviderinstance.DefaultCreatedAt() _c.mutation.SetCreatedAt(v) @@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error { if _, ok := _c.mutation.RefundEnabled(); !ok { return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)} } + if _, ok := _c.mutation.AllowUserRefund(); !ok { + return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)} + } if _, ok := _c.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)} } @@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance, _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) _node.RefundEnabled = value } + if value, ok := _c.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + _node.AllowUserRefund = value + } if value, ok := _c.mutation.CreatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value @@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn return u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert { + u.Set(paymentproviderinstance.FieldAllowUserRefund, v) + return u +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert { + u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund) + return u +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert { u.Set(paymentproviderinstance.FieldUpdatedAt, v) @@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne { return u.Update(func(s *PaymentProviderInstanceUpsert) { @@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid }) } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.SetAllowUserRefund(v) + }) +} + +// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create. +func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk { + return u.Update(func(s *PaymentProviderInstanceUpsert) { + s.UpdateAllowUserRefund() + }) +} + // SetUpdatedAt sets the "updated_at" field. func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk { return u.Update(func(s *PaymentProviderInstanceUpsert) { diff --git a/backend/ent/paymentproviderinstance_update.go b/backend/ent/paymentproviderinstance_update.go index 06dba527..6bb3a82d 100644 --- a/backend/ent/paymentproviderinstance_update.go +++ b/backend/ent/paymentproviderinstance_update.go @@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate { _u.mutation.SetUpdatedAt(v) @@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } @@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P return _u } +// SetAllowUserRefund sets the "allow_user_refund" field. +func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne { + _u.mutation.SetAllowUserRefund(v) + return _u +} + +// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil. +func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne { + if v != nil { + _u.SetAllowUserRefund(*v) + } + return _u +} + // SetUpdatedAt sets the "updated_at" field. func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne { _u.mutation.SetUpdatedAt(v) @@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node if value, ok := _u.mutation.RefundEnabled(); ok { _spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value) } + if value, ok := _u.mutation.AllowUserRefund(); ok { + _spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value) + } if value, ok := _u.mutation.UpdatedAt(); ok { _spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 821b7d66..fbdd08c7 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -668,12 +668,16 @@ func init() { paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor() // paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field. paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool) + // paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field. + paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor() + // paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field. + paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool) // paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field. - paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor() + paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor() // paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field. paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time) // paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field. - paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor() + paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor() // paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field. paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time) // paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. @@ -1293,6 +1297,22 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field. + userDescBalanceNotifyEnabled := userFields[11].Descriptor() + // user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field. + user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool) + // userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field. + userDescBalanceNotifyThresholdType := userFields[12].Descriptor() + // user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field. + user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string) + // userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field. + userDescBalanceNotifyExtraEmails := userFields[14].Descriptor() + // user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field. + user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string) + // userDescTotalRecharged is the schema descriptor for total_recharged field. + userDescTotalRecharged := userFields[15].Descriptor() + // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. + user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/payment_provider_instance.go b/backend/ent/schema/payment_provider_instance.go index 08ab7d31..e4c0b72c 100644 --- a/backend/ent/schema/payment_provider_instance.go +++ b/backend/ent/schema/payment_provider_instance.go @@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field { Default(""), field.Bool("refund_enabled"). Default(false), + field.Bool("allow_user_refund"). + Default(false), field.Time("created_at"). Immutable(). Default(time.Now). diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index af143d38..ef52e985 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,22 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). Nillable(), + + // 余额不足通知 + field.Bool("balance_notify_enabled"). + Default(true), + field.String("balance_notify_threshold_type"). + Default("fixed"), // "fixed" | "percentage" + field.Float("balance_notify_threshold"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Optional(). + Nillable(), + field.String("balance_notify_extra_emails"). + SchemaType(map[string]string{dialect.Postgres: "text"}). + Default("[]"), + field.Float("total_recharged"). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). + Default(0), } } diff --git a/backend/ent/user.go b/backend/ent/user.go index a0eef2ba..9fa91f74 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,16 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field. + BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"` + // BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field. + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"` + // BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field. + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` + // BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field. + BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"` + // TotalRecharged holds the value of the "total_recharged" field. + TotalRecharged float64 `json:"total_recharged,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -184,13 +194,13 @@ func (*User) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case user.FieldTotpEnabled: + case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled: values[i] = new(sql.NullBool) - case user.FieldBalance: + case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged: values[i] = new(sql.NullFloat64) case user.FieldID, user.FieldConcurrency: values[i] = new(sql.NullInt64) - case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: + case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt: values[i] = new(sql.NullTime) @@ -302,6 +312,37 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldBalanceNotifyEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i]) + } else if value.Valid { + _m.BalanceNotifyEnabled = value.Bool + } + case user.FieldBalanceNotifyThresholdType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i]) + } else if value.Valid { + _m.BalanceNotifyThresholdType = value.String + } + case user.FieldBalanceNotifyThreshold: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i]) + } else if value.Valid { + _m.BalanceNotifyThreshold = new(float64) + *_m.BalanceNotifyThreshold = value.Float64 + } + case user.FieldBalanceNotifyExtraEmails: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i]) + } else if value.Valid { + _m.BalanceNotifyExtraEmails = value.String + } + case user.FieldTotalRecharged: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field total_recharged", values[i]) + } else if value.Valid { + _m.TotalRecharged = value.Float64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -440,6 +481,23 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("balance_notify_enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled)) + builder.WriteString(", ") + builder.WriteString("balance_notify_threshold_type=") + builder.WriteString(_m.BalanceNotifyThresholdType) + builder.WriteString(", ") + if v := _m.BalanceNotifyThreshold; v != nil { + builder.WriteString("balance_notify_threshold=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("balance_notify_extra_emails=") + builder.WriteString(_m.BalanceNotifyExtraEmails) + builder.WriteString(", ") + builder.WriteString("total_recharged=") + builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 338518a8..d88a3a38 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,16 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database. + FieldBalanceNotifyEnabled = "balance_notify_enabled" + // FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database. + FieldBalanceNotifyThresholdType = "balance_notify_threshold_type" + // FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database. + FieldBalanceNotifyThreshold = "balance_notify_threshold" + // FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database. + FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails" + // FieldTotalRecharged holds the string denoting the total_recharged field in the database. + FieldTotalRecharged = "total_recharged" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -161,6 +171,11 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldBalanceNotifyEnabled, + FieldBalanceNotifyThresholdType, + FieldBalanceNotifyThreshold, + FieldBalanceNotifyExtraEmails, + FieldTotalRecharged, } var ( @@ -217,6 +232,14 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field. + DefaultBalanceNotifyEnabled bool + // DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field. + DefaultBalanceNotifyThresholdType string + // DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field. + DefaultBalanceNotifyExtraEmails string + // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field. + DefaultTotalRecharged float64 ) // OrderOption defines the ordering options for the User queries. @@ -297,6 +320,31 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field. +func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc() +} + +// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field. +func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc() +} + +// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field. +func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc() +} + +// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field. +func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc() +} + +// ByTotalRecharged orders the results by the total_recharged field. +func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index b1d1000f..2788aa7a 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,31 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ. +func BalanceNotifyEnabled(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) +} + +// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ. +func BalanceNotifyThresholdType(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ. +func BalanceNotifyThreshold(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ. +func BalanceNotifyExtraEmails(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v)) +} + +// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ. +func TotalRecharged(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -860,6 +885,236 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field. +func BalanceNotifyEnabledEQ(v bool) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v)) +} + +// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field. +func BalanceNotifyEnabledNEQ(v bool) predicate.User { + return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v)) +} + +// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...)) +} + +// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...)) +} + +// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field. +func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v)) +} + +// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdEQ(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdNEQ(v float64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...)) +} + +// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...)) +} + +// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdGT(v float64) predicate.User { + return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdGTE(v float64) predicate.User { + return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdLT(v float64) predicate.User { + return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdLTE(v float64) predicate.User { + return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v)) +} + +// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdIsNil() predicate.User { + return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold)) +} + +// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field. +func BalanceNotifyThresholdNotNil() predicate.User { + return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold)) +} + +// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsEQ(v string) predicate.User { + return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsNEQ(v string) predicate.User { + return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User { + return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...)) +} + +// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User { + return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...)) +} + +// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsGT(v string) predicate.User { + return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsGTE(v string) predicate.User { + return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsLT(v string) predicate.User { + return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsLTE(v string) predicate.User { + return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsContains(v string) predicate.User { + return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User { + return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User { + return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User { + return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v)) +} + +// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field. +func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User { + return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v)) +} + +// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field. +func TotalRechargedEQ(v float64) predicate.User { + return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) +} + +// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field. +func TotalRechargedNEQ(v float64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v)) +} + +// TotalRechargedIn applies the In predicate on the "total_recharged" field. +func TotalRechargedIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...)) +} + +// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field. +func TotalRechargedNotIn(vs ...float64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...)) +} + +// TotalRechargedGT applies the GT predicate on the "total_recharged" field. +func TotalRechargedGT(v float64) predicate.User { + return predicate.User(sql.FieldGT(FieldTotalRecharged, v)) +} + +// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field. +func TotalRechargedGTE(v float64) predicate.User { + return predicate.User(sql.FieldGTE(FieldTotalRecharged, v)) +} + +// TotalRechargedLT applies the LT predicate on the "total_recharged" field. +func TotalRechargedLT(v float64) predicate.User { + return predicate.User(sql.FieldLT(FieldTotalRecharged, v)) +} + +// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field. +func TotalRechargedLTE(v float64) predicate.User { + return predicate.User(sql.FieldLTE(FieldTotalRecharged, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index 7f1c5df1..fbc64f9c 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -211,6 +211,76 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate { + _c.mutation.SetBalanceNotifyEnabled(v) + return _c +} + +// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil. +func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate { + if v != nil { + _c.SetBalanceNotifyEnabled(*v) + } + return _c +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate { + _c.mutation.SetBalanceNotifyThresholdType(v) + return _c +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate { + if v != nil { + _c.SetBalanceNotifyThresholdType(*v) + } + return _c +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate { + _c.mutation.SetBalanceNotifyThreshold(v) + return _c +} + +// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil. +func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate { + if v != nil { + _c.SetBalanceNotifyThreshold(*v) + } + return _c +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate { + _c.mutation.SetBalanceNotifyExtraEmails(v) + return _c +} + +// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil. +func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate { + if v != nil { + _c.SetBalanceNotifyExtraEmails(*v) + } + return _c +} + +// SetTotalRecharged sets the "total_recharged" field. +func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate { + _c.mutation.SetTotalRecharged(v) + return _c +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate { + if v != nil { + _c.SetTotalRecharged(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -440,6 +510,22 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { + v := user.DefaultBalanceNotifyEnabled + _c.mutation.SetBalanceNotifyEnabled(v) + } + if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok { + v := user.DefaultBalanceNotifyThresholdType + _c.mutation.SetBalanceNotifyThresholdType(v) + } + if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok { + v := user.DefaultBalanceNotifyExtraEmails + _c.mutation.SetBalanceNotifyExtraEmails(v) + } + if _, ok := _c.mutation.TotalRecharged(); !ok { + v := user.DefaultTotalRecharged + _c.mutation.SetTotalRecharged(v) + } return nil } @@ -503,6 +589,18 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok { + return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)} + } + if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok { + return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)} + } + if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok { + return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)} + } + if _, ok := _c.mutation.TotalRecharged(); !ok { + return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)} + } return nil } @@ -586,6 +684,26 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.BalanceNotifyEnabled(); ok { + _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) + _node.BalanceNotifyEnabled = value + } + if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + _node.BalanceNotifyThresholdType = value + } + if value, ok := _c.mutation.BalanceNotifyThreshold(); ok { + _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) + _node.BalanceNotifyThreshold = &value + } + if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok { + _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) + _node.BalanceNotifyExtraEmails = value + } + if value, ok := _c.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + _node.TotalRecharged = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -988,6 +1106,84 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert { + u.Set(user.FieldBalanceNotifyEnabled, v) + return u +} + +// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create. +func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert { + u.SetExcluded(user.FieldBalanceNotifyEnabled) + return u +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert { + u.Set(user.FieldBalanceNotifyThresholdType, v) + return u +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert { + u.SetExcluded(user.FieldBalanceNotifyThresholdType) + return u +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert { + u.Set(user.FieldBalanceNotifyThreshold, v) + return u +} + +// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create. +func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert { + u.SetExcluded(user.FieldBalanceNotifyThreshold) + return u +} + +// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field. +func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert { + u.Add(user.FieldBalanceNotifyThreshold, v) + return u +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert { + u.SetNull(user.FieldBalanceNotifyThreshold) + return u +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert { + u.Set(user.FieldBalanceNotifyExtraEmails, v) + return u +} + +// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create. +func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert { + u.SetExcluded(user.FieldBalanceNotifyExtraEmails) + return u +} + +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert { + u.Set(user.FieldTotalRecharged, v) + return u +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert { + u.SetExcluded(user.FieldTotalRecharged) + return u +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert { + u.Add(user.FieldTotalRecharged, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1250,6 +1446,97 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyEnabled(v) + }) +} + +// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyEnabled() + }) +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThresholdType(v) + }) +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThresholdType() + }) +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThreshold(v) + }) +} + +// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field. +func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddBalanceNotifyThreshold(v) + }) +} + +// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThreshold() + }) +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.ClearBalanceNotifyThreshold() + }) +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyExtraEmails(v) + }) +} + +// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyExtraEmails() + }) +} + +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetTotalRecharged(v) + }) +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddTotalRecharged(v) + }) +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateTotalRecharged() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1678,6 +1965,97 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyEnabled(v) + }) +} + +// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyEnabled() + }) +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThresholdType(v) + }) +} + +// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThresholdType() + }) +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyThreshold(v) + }) +} + +// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field. +func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddBalanceNotifyThreshold(v) + }) +} + +// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyThreshold() + }) +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.ClearBalanceNotifyThreshold() + }) +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetBalanceNotifyExtraEmails(v) + }) +} + +// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateBalanceNotifyExtraEmails() + }) +} + +// SetTotalRecharged sets the "total_recharged" field. +func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetTotalRecharged(v) + }) +} + +// AddTotalRecharged adds v to the "total_recharged" field. +func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddTotalRecharged(v) + }) +} + +// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateTotalRecharged() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 8107c980..6b355247 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -243,6 +243,96 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate { + _u.mutation.SetBalanceNotifyEnabled(v) + return _u +} + +// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil. +func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate { + if v != nil { + _u.SetBalanceNotifyEnabled(*v) + } + return _u +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate { + _u.mutation.SetBalanceNotifyThresholdType(v) + return _u +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate { + if v != nil { + _u.SetBalanceNotifyThresholdType(*v) + } + return _u +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate { + _u.mutation.ResetBalanceNotifyThreshold() + _u.mutation.SetBalanceNotifyThreshold(v) + return _u +} + +// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil. +func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate { + if v != nil { + _u.SetBalanceNotifyThreshold(*v) + } + return _u +} + +// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field. +func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate { + _u.mutation.AddBalanceNotifyThreshold(v) + return _u +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate { + _u.mutation.ClearBalanceNotifyThreshold() + return _u +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate { + _u.mutation.SetBalanceNotifyExtraEmails(v) + return _u +} + +// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil. +func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate { + if v != nil { + _u.SetBalanceNotifyExtraEmails(*v) + } + return _u +} + +// SetTotalRecharged sets the "total_recharged" field. +func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate { + _u.mutation.ResetTotalRecharged() + _u.mutation.SetTotalRecharged(v) + return _u +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate { + if v != nil { + _u.SetTotalRecharged(*v) + } + return _u +} + +// AddTotalRecharged adds value to the "total_recharged" field. +func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate { + _u.mutation.AddTotalRecharged(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -746,6 +836,30 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { + _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + } + if value, ok := _u.mutation.BalanceNotifyThreshold(); ok { + _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok { + _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) + } + if _u.mutation.BalanceNotifyThresholdCleared() { + _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64) + } + if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok { + _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) + } + if value, ok := _u.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalRecharged(); ok { + _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1434,6 +1548,96 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field. +func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne { + _u.mutation.SetBalanceNotifyEnabled(v) + return _u +} + +// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne { + if v != nil { + _u.SetBalanceNotifyEnabled(*v) + } + return _u +} + +// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field. +func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne { + _u.mutation.SetBalanceNotifyThresholdType(v) + return _u +} + +// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne { + if v != nil { + _u.SetBalanceNotifyThresholdType(*v) + } + return _u +} + +// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field. +func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne { + _u.mutation.ResetBalanceNotifyThreshold() + _u.mutation.SetBalanceNotifyThreshold(v) + return _u +} + +// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne { + if v != nil { + _u.SetBalanceNotifyThreshold(*v) + } + return _u +} + +// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field. +func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne { + _u.mutation.AddBalanceNotifyThreshold(v) + return _u +} + +// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field. +func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne { + _u.mutation.ClearBalanceNotifyThreshold() + return _u +} + +// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field. +func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne { + _u.mutation.SetBalanceNotifyExtraEmails(v) + return _u +} + +// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne { + if v != nil { + _u.SetBalanceNotifyExtraEmails(*v) + } + return _u +} + +// SetTotalRecharged sets the "total_recharged" field. +func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne { + _u.mutation.ResetTotalRecharged() + _u.mutation.SetTotalRecharged(v) + return _u +} + +// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne { + if v != nil { + _u.SetTotalRecharged(*v) + } + return _u +} + +// AddTotalRecharged adds value to the "total_recharged" field. +func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne { + _u.mutation.AddTotalRecharged(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1967,6 +2171,30 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.BalanceNotifyEnabled(); ok { + _spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok { + _spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value) + } + if value, ok := _u.mutation.BalanceNotifyThreshold(); ok { + _spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok { + _spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value) + } + if _u.mutation.BalanceNotifyThresholdCleared() { + _spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64) + } + if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok { + _spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value) + } + if value, ok := _u.mutation.TotalRecharged(); ok { + _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedTotalRecharged(); ok { + _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.sum b/backend/go.sum index 5e453b6f..691a7fd6 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -185,6 +185,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -220,6 +222,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -253,6 +257,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -282,6 +288,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -314,6 +322,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 618d567c..87e3ff5a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -28,7 +28,7 @@ const ( // DefaultCSPPolicy is the default Content-Security-Policy with nonce support // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware -const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" // UMQ(用户消息队列)模式常量 const ( diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index fe181a2f..cf58316c 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) { configPath := filepath.Join(tempDir, "config.yaml") require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644)) - require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644)) + yamlSafePath := filepath.ToSlash(templatePath) + require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644)) t.Setenv("DATA_DIR", tempDir) cfg, err := Load() require.NoError(t, err) - require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) + require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile) require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate) } diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index d8ad78ce..a3a7000f 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { c.JSON(409, gin.H{ "error": "mixed_channel_warning", "message": mixedErr.Error(), + "details": gin.H{ + "group_id": mixedErr.GroupID, + "group_name": mixedErr.GroupName, + "current_platform": mixedErr.CurrentPlatform, + "other_platform": mixedErr.OtherPlatform, + }, }) return } diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index c92b35bb..9151d018 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "strconv" "strings" @@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` - RestrictModels bool `json:"restrict_models"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` + ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` - RestrictModels *bool `json:"restrict_models"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels *bool `json:"restrict_models"` + Features *string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` + ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } type channelModelPricingRequest struct { @@ -71,18 +80,29 @@ type pricingIntervalRequest struct { SortOrder int `json:"sort_order"` } +type accountStatsPricingRuleRequest struct { + Name string `json:"name"` + GroupIDs []int64 `json:"group_ids"` + AccountIDs []int64 `json:"account_ids"` + Pricing []channelModelPricingRequest `json:"pricing"` +} + type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - BillingModelSource string `json:"billing_model_source"` - RestrictModels bool `json:"restrict_models"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { @@ -112,6 +132,14 @@ type pricingIntervalResponse struct { SortOrder int `json:"sort_order"` } +type accountStatsPricingRuleResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + GroupIDs []int64 `json:"group_ids"` + AccountIDs []int64 `json:"account_ids"` + Pricing []channelModelPricingResponse `json:"pricing"` +} + func channelToResponse(ch *service.Channel) *channelResponse { if ch == nil { return nil @@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse { Description: ch.Description, Status: ch.Status, RestrictModels: ch.RestrictModels, + Features: ch.Features, + FeaturesConfig: ch.FeaturesConfig, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse { for _, p := range ch.ModelPricing { resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) } + + resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats + resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules)) + for _, rule := range ch.AccountStatsPricingRules { + ruleResp := accountStatsPricingRuleResponse{ + ID: rule.ID, + Name: rule.Name, + GroupIDs: rule.GroupIDs, + AccountIDs: rule.AccountIDs, + Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)), + } + if ruleResp.GroupIDs == nil { + ruleResp.GroupIDs = []int64{} + } + if ruleResp.AccountIDs == nil { + ruleResp.AccountIDs = []int64{} + } + for i := range rule.Pricing { + ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i])) + } + resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp) + } + return resp } @@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe billingMode = service.BillingModeToken } platform := r.Platform - if platform == "" { - platform = service.PlatformAnthropic - } intervals := make([]service.PricingInterval, 0, len(r.Intervals)) for _, iv := range r.Intervals { intervals = append(intervals, service.PricingInterval{ @@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe return result } +func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule { + return service.AccountStatsPricingRule{ + Name: r.Name, + GroupIDs: r.GroupIDs, + AccountIDs: r.AccountIDs, + Pricing: pricingRequestToService(r.Pricing), + } +} + // --- Handlers --- // List handles listing channels with pagination @@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) + // Main model_pricing requires a platform; default to anthropic for backward compatibility. + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } + + var statsRules []service.AccountStatsPricingRule + for i, r := range req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } + rule := accountStatsPricingRuleRequestToService(r) + rule.SortOrder = i + statsRules = append(statsRules, rule) + } channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ - Name: req.Name, - Description: req.Description, - GroupIDs: req.GroupIDs, - ModelPricing: pricing, - ModelMapping: req.ModelMapping, - BillingModelSource: req.BillingModelSource, - RestrictModels: req.RestrictModels, + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricing, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + Features: req.Features, + FeaturesConfig: req.FeaturesConfig, + ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, + AccountStatsPricingRules: statsRules, }) if err != nil { response.ErrorFrom(c, err) @@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) { } input := &service.UpdateChannelInput{ - Name: req.Name, - Description: req.Description, - Status: req.Status, - GroupIDs: req.GroupIDs, - ModelMapping: req.ModelMapping, - BillingModelSource: req.BillingModelSource, - RestrictModels: req.RestrictModels, + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + Features: req.Features, + FeaturesConfig: req.FeaturesConfig, + ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) + for i := range pricing { + if pricing[i].Platform == "" { + pricing[i].Platform = service.PlatformAnthropic + } + } input.ModelPricing = &pricing } + if req.AccountStatsPricingRules != nil { + statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules)) + for i, r := range *req.AccountStatsPricingRules { + if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE", + fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1))) + return + } + if len(r.Pricing) == 0 { + response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING", + fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1))) + return + } + rule := accountStatsPricingRuleRequestToService(r) + rule.SortOrder = i + statsRules = append(statsRules, rule) + } + input.AccountStatsPricingRules = &statsRules + } channel, err := h.channelService.Update(c.Request.Context(), id, input) if err != nil { diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 2f4b4440..f218cce4 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) { wantValue: string(service.BillingModeToken), }, { - name: "empty platform defaults to anthropic", + name: "empty platform stays empty", req: channelModelPricingRequest{ Models: []string{"m1"}, Platform: "", }, wantField: "Platform", - wantValue: "anthropic", + wantValue: "", }, } diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index ba751131..bec0f126 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -5,11 +5,10 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" + "log/slog" "net/http" "regexp" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableMetadataPassthrough: settings.EnableMetadataPassthrough, EnableCCHSigning: settings.EnableCCHSigning, + WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), PaymentEnabled: paymentCfg.Enabled, PaymentMinAmount: paymentCfg.MinAmount, PaymentMaxAmount: paymentCfg.MaxAmount, @@ -183,6 +188,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders, PaymentEnabledTypes: paymentCfg.EnabledTypes, PaymentBalanceDisabled: paymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate, PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy, PaymentProductNamePrefix: paymentCfg.ProductNamePrefix, PaymentProductNameSuffix: paymentCfg.ProductNameSuffix, @@ -304,20 +311,29 @@ type UpdateSettingsRequest struct { EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` EnableCCHSigning *bool `json:"enable_cch_signing"` + // Balance low notification + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` + // Payment configuration (integrated into settings, full replace) - PaymentEnabled *bool `json:"payment_enabled"` - PaymentMinAmount *float64 `json:"payment_min_amount"` - PaymentMaxAmount *float64 `json:"payment_max_amount"` - PaymentDailyLimit *float64 `json:"payment_daily_limit"` - PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"` - PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"` - PaymentEnabledTypes []string `json:"payment_enabled_types"` - PaymentBalanceDisabled *bool `json:"payment_balance_disabled"` - PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"` - PaymentProductNamePrefix *string `json:"payment_product_name_prefix"` - PaymentProductNameSuffix *string `json:"payment_product_name_suffix"` - PaymentHelpImageURL *string `json:"payment_help_image_url"` - PaymentHelpText *string `json:"payment_help_text"` + PaymentEnabled *bool `json:"payment_enabled"` + PaymentMinAmount *float64 `json:"payment_min_amount"` + PaymentMaxAmount *float64 `json:"payment_max_amount"` + PaymentDailyLimit *float64 `json:"payment_daily_limit"` + PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"` + PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"` + PaymentEnabledTypes []string `json:"payment_enabled_types"` + PaymentBalanceDisabled *bool `json:"payment_balance_disabled"` + PaymentBalanceRechargeMultiplier *float64 `json:"payment_balance_recharge_multiplier"` + PaymentRechargeFeeRate *float64 `json:"payment_recharge_fee_rate"` + PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"` + PaymentProductNamePrefix *string `json:"payment_product_name_prefix"` + PaymentProductNameSuffix *string `json:"payment_product_name_suffix"` + PaymentHelpImageURL *string `json:"payment_help_image_url"` + PaymentHelpText *string `json:"payment_help_text"` // Cancel rate limit PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"` @@ -881,6 +897,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableCCHSigning }(), + BalanceLowNotifyEnabled: func() bool { + if req.BalanceLowNotifyEnabled != nil { + return *req.BalanceLowNotifyEnabled + } + return previousSettings.BalanceLowNotifyEnabled + }(), + BalanceLowNotifyThreshold: func() float64 { + if req.BalanceLowNotifyThreshold != nil { + return *req.BalanceLowNotifyThreshold + } + return previousSettings.BalanceLowNotifyThreshold + }(), + BalanceLowNotifyRechargeURL: func() string { + if req.BalanceLowNotifyRechargeURL != nil { + return *req.BalanceLowNotifyRechargeURL + } + return previousSettings.BalanceLowNotifyRechargeURL + }(), + AccountQuotaNotifyEnabled: func() bool { + if req.AccountQuotaNotifyEnabled != nil { + return *req.AccountQuotaNotifyEnabled + } + return previousSettings.AccountQuotaNotifyEnabled + }(), + AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry { + if req.AccountQuotaNotifyEmails != nil { + return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails) + } + return previousSettings.AccountQuotaNotifyEmails + }(), } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -892,24 +938,26 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { // Skip if no payment fields were provided (prevents accidental wipe). if h.paymentConfigService != nil && hasPaymentFields(req) { paymentReq := service.UpdatePaymentConfigRequest{ - Enabled: req.PaymentEnabled, - MinAmount: req.PaymentMinAmount, - MaxAmount: req.PaymentMaxAmount, - DailyLimit: req.PaymentDailyLimit, - OrderTimeoutMin: req.PaymentOrderTimeoutMin, - MaxPendingOrders: req.PaymentMaxPendingOrders, - EnabledTypes: req.PaymentEnabledTypes, - BalanceDisabled: req.PaymentBalanceDisabled, - LoadBalanceStrategy: req.PaymentLoadBalanceStrat, - ProductNamePrefix: req.PaymentProductNamePrefix, - ProductNameSuffix: req.PaymentProductNameSuffix, - HelpImageURL: req.PaymentHelpImageURL, - HelpText: req.PaymentHelpText, - CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled, - CancelRateLimitMax: req.PaymentCancelRateLimitMax, - CancelRateLimitWindow: req.PaymentCancelRateLimitWindow, - CancelRateLimitUnit: req.PaymentCancelRateLimitUnit, - CancelRateLimitMode: req.PaymentCancelRateLimitMode, + Enabled: req.PaymentEnabled, + MinAmount: req.PaymentMinAmount, + MaxAmount: req.PaymentMaxAmount, + DailyLimit: req.PaymentDailyLimit, + OrderTimeoutMin: req.PaymentOrderTimeoutMin, + MaxPendingOrders: req.PaymentMaxPendingOrders, + EnabledTypes: req.PaymentEnabledTypes, + BalanceDisabled: req.PaymentBalanceDisabled, + BalanceRechargeMultiplier: req.PaymentBalanceRechargeMultiplier, + RechargeFeeRate: req.PaymentRechargeFeeRate, + LoadBalanceStrategy: req.PaymentLoadBalanceStrat, + ProductNamePrefix: req.PaymentProductNamePrefix, + ProductNameSuffix: req.PaymentProductNameSuffix, + HelpImageURL: req.PaymentHelpImageURL, + HelpText: req.PaymentHelpText, + CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled, + CancelRateLimitMax: req.PaymentCancelRateLimitMax, + CancelRateLimitWindow: req.PaymentCancelRateLimitWindow, + CancelRateLimitUnit: req.PaymentCancelRateLimitUnit, + CancelRateLimitMode: req.PaymentCancelRateLimitMode, } if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil { response.ErrorFrom(c, err) @@ -1027,6 +1075,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, EnableCCHSigning: updatedSettings.EnableCCHSigning, + BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, + BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, + AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, + AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), PaymentEnabled: updatedPaymentCfg.Enabled, PaymentMinAmount: updatedPaymentCfg.MinAmount, PaymentMaxAmount: updatedPaymentCfg.MaxAmount, @@ -1035,6 +1088,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders, PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes, PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled, + PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier, + PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate, PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy, PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix, PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix, @@ -1054,6 +1109,7 @@ func hasPaymentFields(req UpdateSettingsRequest) bool { req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil || req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil || req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil || + req.PaymentBalanceRechargeMultiplier != nil || req.PaymentRechargeFeeRate != nil || req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil || req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil || req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil || @@ -1073,11 +1129,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys subject, _ := middleware.GetAuthSubjectFromContext(c) role, _ := middleware.GetUserRoleFromContext(c) - log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v", - time.Now().UTC().Format(time.RFC3339), - subject.UserID, - role, - changed, + slog.Info("settings updated", + "audit", true, + "user_id", subject.UserID, + "role", role, + "changed", changed, ) } @@ -1092,6 +1148,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) { changed = append(changed, "registration_email_suffix_whitelist") } + if before.PromoCodeEnabled != after.PromoCodeEnabled { + changed = append(changed, "promo_code_enabled") + } + if before.InvitationCodeEnabled != after.InvitationCodeEnabled { + changed = append(changed, "invitation_code_enabled") + } if before.PasswordResetEnabled != after.PasswordResetEnabled { changed = append(changed, "password_reset_enabled") } @@ -1302,6 +1364,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.CustomMenuItems != after.CustomMenuItems { changed = append(changed, "custom_menu_items") } + if before.CustomEndpoints != after.CustomEndpoints { + changed = append(changed, "custom_endpoints") + } if before.EnableFingerprintUnification != after.EnableFingerprintUnification { changed = append(changed, "enable_fingerprint_unification") } @@ -1311,6 +1376,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableCCHSigning != after.EnableCCHSigning { changed = append(changed, "enable_cch_signing") } + // Balance & quota notification + if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { + changed = append(changed, "balance_low_notify_enabled") + } + if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold { + changed = append(changed, "balance_low_notify_threshold") + } + if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL { + changed = append(changed, "balance_low_notify_recharge_url") + } + if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled { + changed = append(changed, "account_quota_notify_enabled") + } + if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) { + changed = append(changed, "account_quota_notify_emails") + } return changed } @@ -1367,6 +1448,18 @@ func equalIntSlice(a, b []int) bool { return true } +func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled { + return false + } + } + return true +} + // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host"` @@ -1847,3 +1940,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) { ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes, }) } + +// GetWebSearchEmulationConfig 获取 Web Search 模拟配置 +// GET /api/v1/admin/settings/web-search-emulation +func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) { + cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg)) +} + +// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置 +// PUT /api/v1/admin/settings/web-search-emulation +func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) { + var cfg service.WebSearchEmulationConfig + if err := c.ShouldBindJSON(&cfg); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil { + response.ErrorFrom(c, err) + return + } + + // Re-read (with sanitized api keys) to return current state + updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated)) +} + +// ResetWebSearchUsage 重置指定 provider 的配额用量 +// POST /api/v1/admin/settings/web-search-emulation/reset-usage +func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) { + var req struct { + ProviderType string `json:"provider_type"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if req.ProviderType == "" { + response.BadRequest(c, "provider_type is required") + return + } + if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, nil) +} + +// TestWebSearchEmulation 测试 Web Search 搜索 +// POST /api/v1/admin/settings/web-search-emulation/test +func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) { + var req struct { + Query string `json:"query"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if strings.TrimSpace(req.Query) == "" { + req.Query = "搜索今年世界大事件" + } + + result, err := service.TestWebSearch(c.Request.Context(), req.Query) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 478600eb..d2ccb8d6 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User { return nil } return &User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - AllowedGroups: u.AllowedGroups, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + AllowedGroups: u.AllowedGroups, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, + BalanceNotifyEnabled: u.BalanceNotifyEnabled, + BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, + BalanceNotifyThreshold: u.BalanceNotifyThreshold, + BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails), + TotalRecharged: u.TotalRecharged, } } @@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account { out.QuotaWeeklyResetAt = &v } } + + // 配额通知配置 + if enabled := a.GetQuotaNotifyDailyEnabled(); enabled { + out.QuotaNotifyDailyEnabled = &enabled + } + if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 { + out.QuotaNotifyDailyThreshold = &threshold + } + if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled { + out.QuotaNotifyWeeklyEnabled = &enabled + } + if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 { + out.QuotaNotifyWeeklyThreshold = &threshold + } + if enabled := a.GetQuotaNotifyTotalEnabled(); enabled { + out.QuotaNotifyTotalEnabled = &enabled + } + if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 { + out.QuotaNotifyTotalThreshold = &threshold + } } return out @@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { ModelMappingChain: l.ModelMappingChain, BillingTier: l.BillingTier, AccountRateMultiplier: l.AccountRateMultiplier, + AccountStatsCost: l.AccountStatsCost, IPAddress: l.IPAddress, Account: AccountSummaryFromService(l.Account), } diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go new file mode 100644 index 00000000..78641005 --- /dev/null +++ b/backend/internal/handler/dto/notify_email_entry.go @@ -0,0 +1,43 @@ +package dto + +import "github.com/Wei-Shaw/sub2api/internal/service" + +// NotifyEmailEntry represents a notification email with enable/disable and verification state. +// All emails are user-managed; maximum 3 entries per user. +type NotifyEmailEntry struct { + Email string `json:"email"` + Disabled bool `json:"disabled"` + Verified bool `json:"verified"` +} + +// NotifyEmailEntriesFromService converts service entries to DTO entries. +func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry { + if entries == nil { + return nil + } + result := make([]NotifyEmailEntry, len(entries)) + for i, e := range entries { + result[i] = NotifyEmailEntry{ + Email: e.Email, + Disabled: e.Disabled, + Verified: e.Verified, + } + } + return result +} + +// NotifyEmailEntriesToService converts DTO entries to service entries. +func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry { + if entries == nil { + return nil + } + result := make([]service.NotifyEmailEntry, len(entries)) + for i, e := range entries { + result[i] = service.NotifyEmailEntry{ + Email: e.Email, + Disabled: e.Disabled, + Verified: e.Verified, + } + } + return result +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index cbbe9216..3659e79b 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -124,20 +124,25 @@ type SystemSettings struct { EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` EnableCCHSigning bool `json:"enable_cch_signing"` + // Web Search Emulation + WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"` + // Payment configuration - PaymentEnabled bool `json:"payment_enabled"` - PaymentMinAmount float64 `json:"payment_min_amount"` - PaymentMaxAmount float64 `json:"payment_max_amount"` - PaymentDailyLimit float64 `json:"payment_daily_limit"` - PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"` - PaymentMaxPendingOrders int `json:"payment_max_pending_orders"` - PaymentEnabledTypes []string `json:"payment_enabled_types"` - PaymentBalanceDisabled bool `json:"payment_balance_disabled"` - PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"` - PaymentProductNamePrefix string `json:"payment_product_name_prefix"` - PaymentProductNameSuffix string `json:"payment_product_name_suffix"` - PaymentHelpImageURL string `json:"payment_help_image_url"` - PaymentHelpText string `json:"payment_help_text"` + PaymentEnabled bool `json:"payment_enabled"` + PaymentMinAmount float64 `json:"payment_min_amount"` + PaymentMaxAmount float64 `json:"payment_max_amount"` + PaymentDailyLimit float64 `json:"payment_daily_limit"` + PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"` + PaymentMaxPendingOrders int `json:"payment_max_pending_orders"` + PaymentEnabledTypes []string `json:"payment_enabled_types"` + PaymentBalanceDisabled bool `json:"payment_balance_disabled"` + PaymentBalanceRechargeMultiplier float64 `json:"payment_balance_recharge_multiplier"` + PaymentRechargeFeeRate float64 `json:"payment_recharge_fee_rate"` + PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"` + PaymentProductNamePrefix string `json:"payment_product_name_prefix"` + PaymentProductNameSuffix string `json:"payment_product_name_suffix"` + PaymentHelpImageURL string `json:"payment_help_image_url"` + PaymentHelpText string `json:"payment_help_text"` // Cancel rate limit PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"` @@ -145,6 +150,13 @@ type SystemSettings struct { PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"` PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"` PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"` + + // Balance low notification + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` } type DefaultSubscriptionSetting struct { @@ -183,6 +195,10 @@ type PublicSettings struct { BackendModeEnabled bool `json:"backend_mode_enabled"` PaymentEnabled bool `json:"payment_enabled"` Version string `json:"version"` + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` } // OverloadCooldownSettings 529过载冷却配置 DTO diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index e026ca65..8c1e166f 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -18,6 +18,13 @@ type User struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` + // 余额不足通知 + BalanceNotifyEnabled bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` + BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` + TotalRecharged float64 `json:"total_recharged"` + APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` } @@ -218,6 +225,14 @@ type Account struct { QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"` QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"` + // 配额通知配置 + QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"` + QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"` + QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"` + QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"` + QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"` + QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"` + Proxy *Proxy `json:"proxy,omitempty"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` @@ -412,6 +427,8 @@ type AdminUsageLog struct { // AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理) AccountRateMultiplier *float64 `json:"account_rate_multiplier"` + // AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式) + AccountStatsCost *float64 `json:"account_stats_cost,omitempty"` // IPAddress 用户请求 IP(仅管理员可见) IPAddress *string `json:"ip_address,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 59619d50..f5eff8c9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } + // 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟) + parsedReq.GroupID = apiKey.GroupID + // 计算粘性会话hash parsedReq.SessionContext = &service.SessionContext{ ClientIP: ip.GetClientIP(c), @@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: apiKey, User: apiKey.User, Account: account, @@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { for { // 选择支持该模型的账号 - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0)) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) @@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { @@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index acea3780..1fdc46ba 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // tlsFPProfileService nil, // channelService nil, // resolver + nil, // balanceNotifyService ) // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index 0425fc49..1ddb8ae2 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -126,26 +126,30 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) { } response.Success(c, checkoutInfoResponse{ - Methods: limitsResp.Methods, - GlobalMin: limitsResp.GlobalMin, - GlobalMax: limitsResp.GlobalMax, - Plans: planList, - BalanceDisabled: cfg.BalanceDisabled, - HelpText: cfg.HelpText, - HelpImageURL: cfg.HelpImageURL, - StripePublishableKey: cfg.StripePublishableKey, + Methods: limitsResp.Methods, + GlobalMin: limitsResp.GlobalMin, + GlobalMax: limitsResp.GlobalMax, + Plans: planList, + BalanceDisabled: cfg.BalanceDisabled, + BalanceRechargeMultiplier: cfg.BalanceRechargeMultiplier, + RechargeFeeRate: cfg.RechargeFeeRate, + HelpText: cfg.HelpText, + HelpImageURL: cfg.HelpImageURL, + StripePublishableKey: cfg.StripePublishableKey, }) } type checkoutInfoResponse struct { - Methods map[string]service.MethodLimits `json:"methods"` - GlobalMin float64 `json:"global_min"` - GlobalMax float64 `json:"global_max"` - Plans []checkoutPlan `json:"plans"` - BalanceDisabled bool `json:"balance_disabled"` - HelpText string `json:"help_text"` - HelpImageURL string `json:"help_image_url"` - StripePublishableKey string `json:"stripe_publishable_key"` + Methods map[string]service.MethodLimits `json:"methods"` + GlobalMin float64 `json:"global_min"` + GlobalMax float64 `json:"global_max"` + Plans []checkoutPlan `json:"plans"` + BalanceDisabled bool `json:"balance_disabled"` + BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"` + RechargeFeeRate float64 `json:"recharge_fee_rate"` + HelpText string `json:"help_text"` + HelpImageURL string `json:"help_image_url"` + StripePublishableKey string `json:"stripe_publishable_key"` } type checkoutPlan struct { @@ -335,6 +339,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) { response.Success(c, gin.H{"message": "refund requested"}) } +// GetRefundEligibleProviders returns provider instance IDs that allow user refund. +func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) { + ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"provider_instance_ids": ids}) +} + // VerifyOrderRequest is the request body for verifying a payment order. type VerifyOrderRequest struct { OutTradeNo string `json:"out_trade_no" binding:"required"` @@ -371,6 +385,7 @@ type PublicOrderResult struct { Amount float64 `json:"amount"` PayAmount float64 `json:"pay_amount"` PaymentType string `json:"payment_type"` + OrderType string `json:"order_type"` Status string `json:"status"` } @@ -394,6 +409,7 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) { Amount: order.Amount, PayAmount: order.PayAmount, PaymentType: order.PaymentType, + OrderType: order.OrderType, Status: order.Status, }) } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 54a92a8c..1717b7a1 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { BackendModeEnabled: settings.BackendModeEnabled, PaymentEnabled: settings.PaymentEnabled, Version: h.version, + BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, + AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, + BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, + BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, }) } diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 35862f1c..2535ea5e 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -11,13 +11,17 @@ import ( // UserHandler handles user-related requests type UserHandler struct { - userService *service.UserService + userService *service.UserService + emailService *service.EmailService + emailCache service.EmailCache } // NewUserHandler creates a new UserHandler -func NewUserHandler(userService *service.UserService) *UserHandler { +func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler { return &UserHandler{ - userService: userService, + userService: userService, + emailService: emailService, + emailCache: emailCache, } } @@ -29,7 +33,9 @@ type ChangePasswordRequest struct { // UpdateProfileRequest represents the update profile request payload type UpdateProfileRequest struct { - Username *string `json:"username"` + Username *string `json:"username"` + BalanceNotifyEnabled *bool `json:"balance_notify_enabled"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"` } // GetProfile handles getting user profile @@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { } svcReq := service.UpdateProfileRequest{ - Username: req.Username, + Username: req.Username, + BalanceNotifyEnabled: req.BalanceNotifyEnabled, + BalanceNotifyThreshold: req.BalanceNotifyThreshold, } updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq) if err != nil { @@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { response.Success(c, dto.UserFromService(updatedUser)) } + +// SendNotifyEmailCodeRequest represents the request to send notify email verification code +type SendNotifyEmailCodeRequest struct { + Email string `json:"email" binding:"required,email"` +} + +// SendNotifyEmailCode sends verification code to extra notification email +// POST /api/v1/user/notify-email/send-code +func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req SendNotifyEmailCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Verification code sent successfully"}) +} + +// VerifyNotifyEmailRequest represents the request to verify and add notify email +type VerifyNotifyEmailRequest struct { + Email string `json:"email" binding:"required,email"` + Code string `json:"code" binding:"required,len=6"` +} + +// VerifyNotifyEmail verifies code and adds email to notification list +// POST /api/v1/user/notify-email/verify +func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req VerifyNotifyEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // Return updated user + updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.UserFromService(updatedUser)) +} + +// RemoveNotifyEmailRequest represents the request to remove a notify email +type RemoveNotifyEmailRequest struct { + Email string `json:"email" binding:"required,email"` +} + +// RemoveNotifyEmail removes email from notification list +// DELETE /api/v1/user/notify-email +func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req RemoveNotifyEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // Return updated user + updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.UserFromService(updatedUser)) +} + +// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state +type ToggleNotifyEmailRequest struct { + Email string `json:"email" binding:"required,email"` + Disabled bool `json:"disabled"` +} + +// ToggleNotifyEmail toggles the disabled state of a notification email +// PUT /api/v1/user/notify-email/toggle +func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + var req ToggleNotifyEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled) + if err != nil { + response.ErrorFrom(c, err) + return + } + + updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, dto.UserFromService(updatedUser)) +} diff --git a/backend/internal/payment/load_balancer.go b/backend/internal/payment/load_balancer.go index afe607e0..f0353173 100644 --- a/backend/internal/payment/load_balancer.go +++ b/backend/internal/payment/load_balancer.go @@ -94,17 +94,21 @@ func (lb *DefaultLoadBalancer) SelectInstance( return lb.buildSelection(selected.inst) } -// queryEnabledInstances returns enabled instances for providerKey that support paymentType. +// queryEnabledInstances returns enabled instances that support paymentType. +// When providerKey is non-empty, only instances with that provider key are considered. +// When providerKey is empty, instances across all providers are considered, +// enabling cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). func (lb *DefaultLoadBalancer) queryEnabledInstances( ctx context.Context, providerKey string, paymentType PaymentType, ) ([]*dbent.PaymentProviderInstance, error) { - instances, err := lb.db.PaymentProviderInstance.Query(). - Where( - paymentproviderinstance.ProviderKey(providerKey), - paymentproviderinstance.Enabled(true), - ). + query := lb.db.PaymentProviderInstance.Query(). + Where(paymentproviderinstance.Enabled(true)) + if providerKey != "" { + query = query.Where(paymentproviderinstance.ProviderKey(providerKey)) + } + instances, err := query. Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)). All(ctx) if err != nil { @@ -113,12 +117,18 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances( var matched []*dbent.PaymentProviderInstance for _, inst := range instances { - if paymentType == providerKey || InstanceSupportsType(inst.SupportedTypes, paymentType) { + // Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay), + // not "stripe" itself. The checkout page aggregates all sub-types under "stripe". + if paymentType == TypeStripe { + if inst.ProviderKey == TypeStripe { + matched = append(matched, inst) + } + } else if InstanceSupportsType(inst.SupportedTypes, paymentType) { matched = append(matched, inst) } } if len(matched) == 0 { - return nil, fmt.Errorf("no enabled instance for provider %s type %s", providerKey, paymentType) + return nil, fmt.Errorf("no enabled instance for payment type %s", paymentType) } return matched, nil } @@ -258,6 +268,7 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns return &InstanceSelection{ InstanceID: fmt.Sprintf("%d", selected.ID), + ProviderKey: selected.ProviderKey, Config: config, SupportedTypes: selected.SupportedTypes, PaymentMode: selected.PaymentMode, diff --git a/backend/internal/payment/load_balancer_test.go b/backend/internal/payment/load_balancer_test.go index 568b56a3..04b3c25b 100644 --- a/backend/internal/payment/load_balancer_test.go +++ b/backend/internal/payment/load_balancer_test.go @@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) { wantIDs: nil, }, { - name: "empty candidates returns empty", + name: "empty candidates returns empty", candidates: nil, paymentType: "alipay", orderAmount: 10, diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index 3eca0b2c..af8a90c6 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -76,7 +76,7 @@ func (a *Alipay) getClient() (*alipay.Client, error) { func (a *Alipay) Name() string { return "Alipay" } func (a *Alipay) ProviderKey() string { return payment.TypeAlipay } func (a *Alipay) SupportedTypes() []payment.PaymentType { - return []payment.PaymentType{payment.TypeAlipayDirect} + return []payment.PaymentType{payment.TypeAlipay} } // CreatePayment creates an Alipay payment page URL. diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 1b9d66ba..7b0ce0d8 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) { errSubstr: "privateKey", }, { - name: "nil config map returns error for appId", - config: map[string]string{}, - wantErr: true, + name: "nil config map returns error for appId", + config: map[string]string{}, + wantErr: true, errSubstr: "appId", }, } diff --git a/backend/internal/payment/provider/wxpay.go b/backend/internal/payment/provider/wxpay.go index 14e51cd2..0b41c4fb 100644 --- a/backend/internal/payment/provider/wxpay.go +++ b/backend/internal/payment/provider/wxpay.go @@ -72,7 +72,7 @@ func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) { func (w *Wxpay) Name() string { return "Wxpay" } func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay } func (w *Wxpay) SupportedTypes() []payment.PaymentType { - return []payment.PaymentType{payment.TypeWxpayDirect} + return []payment.PaymentType{payment.TypeWxpay} } func formatPEM(key, keyType string) string { diff --git a/backend/internal/payment/types.go b/backend/internal/payment/types.go index c413d8f3..5d613a4a 100644 --- a/backend/internal/payment/types.go +++ b/backend/internal/payment/types.go @@ -148,6 +148,7 @@ type RefundResponse struct { // InstanceSelection holds the selected provider instance and its decrypted config. type InstanceSelection struct { InstanceID string + ProviderKey string // Provider key of the selected instance (e.g. "alipay", "easypay") Config map[string]string SupportedTypes string // Comma-separated list of supported payment types from the instance PaymentMode string // Payment display mode: "qrcode", "redirect", "popup" diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 9c03bc24..8dec839c 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -18,6 +18,9 @@ const ( BlockTypeFunction ) +// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events. +type UsageMapHook func(usageMap map[string]any) + // StreamingProcessor 流式响应处理器 type StreamingProcessor struct { blockType BlockType @@ -30,6 +33,7 @@ type StreamingProcessor struct { originalModel string webSearchQueries []string groundingChunks []GeminiGroundingChunk + usageMapHook UsageMapHook // 累计 usage inputTokens int @@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor { } } +// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted. +func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) { + p.usageMapHook = fn +} + +func usageToMap(u ClaudeUsage) map[string]any { + m := map[string]any{ + "input_tokens": u.InputTokens, + "output_tokens": u.OutputTokens, + } + if u.CacheCreationInputTokens > 0 { + m["cache_creation_input_tokens"] = u.CacheCreationInputTokens + } + if u.CacheReadInputTokens > 0 { + m["cache_read_input_tokens"] = u.CacheReadInputTokens + } + if u.ImageOutputTokens > 0 { + m["image_output_tokens"] = u.ImageOutputTokens + } + return m +} + // ProcessLine 处理 SSE 行,返回 Claude SSE 事件 func (p *StreamingProcessor) ProcessLine(line string) []byte { line = strings.TrimSpace(line) @@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte responseID = "msg_" + generateRandomID() } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + message := map[string]any{ "id": responseID, "type": "message", @@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte "model": p.originalModel, "stop_reason": nil, "stop_sequence": nil, - "usage": usage, + "usage": usageValue, } event := map[string]any{ @@ -496,13 +529,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { ImageOutputTokens: p.imageOutputTokens, } + var usageValue any = usage + if p.usageMapHook != nil { + usageMap := usageToMap(usage) + p.usageMapHook(usageMap) + usageValue = usageMap + } + deltaEvent := map[string]any{ "type": "message_delta", "delta": map[string]any{ "stop_reason": stopReason, "stop_sequence": nil, }, - "usage": usage, + "usage": usageValue, } _, _ = result.Write(p.formatSSE("message_delta", deltaEvent)) diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index dc157a6d..c2725406 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest, } out := &ResponsesRequest{ - Model: req.Model, - Input: inputJSON, - Temperature: req.Temperature, - TopP: req.TopP, - Stream: true, // upstream always streams - Include: []string{"reasoning.encrypted_content"}, - ServiceTier: req.ServiceTier, + Model: req.Model, + Instructions: req.Instructions, + Input: inputJSON, + Temperature: req.Temperature, + TopP: req.TopP, + Stream: true, // upstream always streams + Include: []string{"reasoning.encrypted_content"}, + ServiceTier: req.ServiceTier, } storeFalse := false diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index b383f867..e0d1a53e 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -152,6 +152,7 @@ type AnthropicDelta struct { // ResponsesRequest is the request body for POST /v1/responses. type ResponsesRequest struct { Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` Input json.RawMessage `json:"input"` // string or []ResponsesInputItem MaxOutputTokens *int `json:"max_output_tokens,omitempty"` Temperature *float64 `json:"temperature,omitempty"` @@ -337,6 +338,7 @@ type ResponsesStreamEvent struct { type ChatCompletionsRequest struct { Model string `json:"model"` Messages []ChatMessage `json:"messages"` + Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat MaxTokens *int `json:"max_tokens,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` Temperature *float64 `json:"temperature,omitempty"` diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go index 74aae061..06a277a4 100644 --- a/backend/internal/pkg/logger/logger_test.go +++ b/backend/internal/pkg/logger/logger_test.go @@ -10,7 +10,13 @@ import ( ) func TestInit_DualOutput(t *testing.T) { - tmpDir := t.TempDir() + // Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures + // when lumberjack holds file handles on Windows. + tmpDir, err := os.MkdirTemp("", "logger-test-*") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) logPath := filepath.Join(tmpDir, "logs", "sub2api.log") origStdout := os.Stdout @@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) { L().Info("dual-output-info") L().Warn("dual-output-warn") - Sync() + + // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers). + // The log data is already in the pipe buffer; closing writers is sufficient. _ = stdoutW.Close() _ = stderrW.Close() @@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) { } L().Info("caller-check") - Sync() + // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers). + os.Stdout = origStdout + os.Stderr = origStderr _ = stdoutW.Close() logBytes, _ := io.ReadAll(stdoutR) diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go index 4482a2ec..30d25b33 100644 --- a/backend/internal/pkg/logger/stdlog_bridge_test.go +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) { log.Printf("service started") log.Printf("Warning: queue full") log.Printf("Forward request failed: timeout") - Sync() + // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers). _ = stdoutW.Close() _ = stderrW.Close() @@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) { LegacyPrintf("service.test", "request started") LegacyPrintf("service.test", "Warning: queue full") LegacyPrintf("service.test", "forward failed: timeout") - Sync() + // Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers). _ = stdoutW.Close() _ = stderrW.Close() diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 5d1f7911..fe5f98d6 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -56,8 +56,9 @@ type DashboardStats struct { TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"` TotalCacheReadTokens int64 `json:"total_cache_read_tokens"` TotalTokens int64 `json:"total_tokens"` - TotalCost float64 `json:"total_cost"` // 累计标准计费 - TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除 + TotalCost float64 `json:"total_cost"` // 累计标准计费 + TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除 + TotalAccountCost float64 `json:"total_account_cost"` // 累计账号成本 // 今日 Token 使用统计 TodayRequests int64 `json:"today_requests"` @@ -66,8 +67,9 @@ type DashboardStats struct { TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"` TodayCacheReadTokens int64 `json:"today_cache_read_tokens"` TodayTokens int64 `json:"today_tokens"` - TodayCost float64 `json:"today_cost"` // 今日标准计费 - TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除 + TodayCost float64 `json:"today_cost"` // 今日标准计费 + TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除 + TodayAccountCost float64 `json:"today_account_cost"` // 今日账号成本 // 系统运行统计 AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间 @@ -99,8 +101,9 @@ type ModelStat struct { CacheCreationTokens int64 `json:"cache_creation_tokens"` CacheReadTokens int64 `json:"cache_read_tokens"` TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 + AccountCost float64 `json:"account_cost"` // 账号成本 } // EndpointStat represents usage statistics for a single request endpoint. @@ -125,8 +128,9 @@ type GroupStat struct { GroupName string `json:"group_name"` Requests int64 `json:"requests"` TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 + AccountCost float64 `json:"account_cost"` // 账号成本 } // UserUsageTrendPoint represents user usage trend data point @@ -164,8 +168,9 @@ type UserBreakdownItem struct { Email string `json:"email"` Requests int64 `json:"requests"` TotalTokens int64 `json:"total_tokens"` - Cost float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 + Cost float64 `json:"cost"` // 标准计费 + ActualCost float64 `json:"actual_cost"` // 实际扣除 + AccountCost float64 `json:"account_cost"` // 账号成本 } // UserBreakdownDimension specifies the dimension to filter for user breakdown. diff --git a/backend/internal/pkg/websearch/brave.go b/backend/internal/pkg/websearch/brave.go new file mode 100644 index 00000000..707e7029 --- /dev/null +++ b/backend/internal/pkg/websearch/brave.go @@ -0,0 +1,106 @@ +package websearch + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" +) + +const ( + braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search" + braveMaxCount = 20 + braveProviderName = "brave" +) + +// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal. +var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck + +// BraveProvider implements web search via the Brave Search API. +type BraveProvider struct { + apiKey string + httpClient *http.Client +} + +// NewBraveProvider creates a Brave Search provider. +// The caller is responsible for configuring the http.Client with proxy/timeouts. +func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &BraveProvider{apiKey: apiKey, httpClient: httpClient} +} + +func (b *BraveProvider) Name() string { return braveProviderName } + +func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) { + count := req.MaxResults + if count <= 0 { + count = defaultMaxResults + } + if count > braveMaxCount { + count = braveMaxCount + } + + u := *braveSearchURL // copy the pre-parsed URL + q := u.Query() + q.Set("q", req.Query) + q.Set("count", strconv.Itoa(count)) + u.RawQuery = q.Encode() + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("brave: build request: %w", err) + } + httpReq.Header.Set("X-Subscription-Token", b.apiKey) + httpReq.Header.Set("Accept", "application/json") + + resp, err := b.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("brave: request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, fmt.Errorf("brave: read body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body)) + } + + var raw braveResponse + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("brave: decode response: %w", err) + } + + results := make([]SearchResult, 0, len(raw.Web.Results)) + for _, r := range raw.Web.Results { + results = append(results, SearchResult{ + URL: r.URL, + Title: r.Title, + Snippet: r.Description, + PageAge: r.Age, + }) + } + + return &SearchResponse{Results: results, Query: req.Query}, nil +} + +// braveResponse is the minimal structure of the Brave Search API response. +type braveResponse struct { + Web struct { + Results []braveResult `json:"results"` + } `json:"web"` +} + +type braveResult struct { + URL string `json:"url"` + Title string `json:"title"` + Description string `json:"description"` + Age string `json:"age"` +} diff --git a/backend/internal/pkg/websearch/brave_test.go b/backend/internal/pkg/websearch/brave_test.go new file mode 100644 index 00000000..4dc5b219 --- /dev/null +++ b/backend/internal/pkg/websearch/brave_test.go @@ -0,0 +1,119 @@ +package websearch + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBraveProvider_Name(t *testing.T) { + p := NewBraveProvider("key", nil) + require.Equal(t, "brave", p.Name()) +} + +func TestBraveProvider_Search_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token")) + require.Equal(t, "application/json", r.Header.Get("Accept")) + require.Equal(t, "golang", r.URL.Query().Get("q")) + require.Equal(t, "3", r.URL.Query().Get("count")) + + resp := braveResponse{} + resp.Web.Results = []braveResult{ + {URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"}, + {URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"}, + {URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := NewBraveProvider("test-key", srv.Client()) + // Override the endpoint for testing + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3}) + require.NoError(t, err) + require.Len(t, resp.Results, 3) + require.Equal(t, "https://go.dev", resp.Results[0].URL) + require.Equal(t, "Go lang", resp.Results[0].Snippet) + require.Equal(t, "1 day", resp.Results[0].PageAge) +} + +func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) { + var receivedCount string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedCount = r.URL.Query().Get("count") + resp := braveResponse{} + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := NewBraveProvider("key", srv.Client()) + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + _, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0}) + require.Equal(t, "5", receivedCount) +} + +func TestBraveProvider_Search_HTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(429) + _, _ = w.Write([]byte("rate limited")) + })) + defer srv.Close() + + p := NewBraveProvider("key", srv.Client()) + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + _, err := p.Search(context.Background(), SearchRequest{Query: "test"}) + require.ErrorContains(t, err, "brave: status 429") +} + +func TestBraveProvider_Search_InvalidJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + p := NewBraveProvider("key", srv.Client()) + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + _, err := p.Search(context.Background(), SearchRequest{Query: "test"}) + require.ErrorContains(t, err, "brave: decode response") +} + +func TestBraveProvider_Search_EmptyResults(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + resp := braveResponse{} + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + p := NewBraveProvider("key", srv.Client()) + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + resp, err := p.Search(context.Background(), SearchRequest{Query: "test"}) + require.NoError(t, err) + require.Empty(t, resp.Results) +} diff --git a/backend/internal/pkg/websearch/helpers.go b/backend/internal/pkg/websearch/helpers.go new file mode 100644 index 00000000..0d08b749 --- /dev/null +++ b/backend/internal/pkg/websearch/helpers.go @@ -0,0 +1,14 @@ +package websearch + +const ( + maxResponseSize = 1 << 20 // 1 MB + errorBodyTruncLen = 200 +) + +// truncateBody returns a truncated string of body for error messages. +func truncateBody(body []byte) string { + if len(body) <= errorBodyTruncLen { + return string(body) + } + return string(body[:errorBodyTruncLen]) + "...(truncated)" +} diff --git a/backend/internal/pkg/websearch/helpers_test.go b/backend/internal/pkg/websearch/helpers_test.go new file mode 100644 index 00000000..e3164329 --- /dev/null +++ b/backend/internal/pkg/websearch/helpers_test.go @@ -0,0 +1,25 @@ +package websearch + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTruncateBody_Short(t *testing.T) { + body := []byte("short body") + require.Equal(t, "short body", truncateBody(body)) +} + +func TestTruncateBody_Long(t *testing.T) { + body := []byte(strings.Repeat("x", 500)) + result := truncateBody(body) + require.Len(t, result, errorBodyTruncLen+len("...(truncated)")) + require.True(t, strings.HasSuffix(result, "...(truncated)")) +} + +func TestTruncateBody_ExactBoundary(t *testing.T) { + body := []byte(strings.Repeat("x", errorBodyTruncLen)) + require.Equal(t, string(body), truncateBody(body)) +} diff --git a/backend/internal/pkg/websearch/manager.go b/backend/internal/pkg/websearch/manager.go new file mode 100644 index 00000000..307aa1e9 --- /dev/null +++ b/backend/internal/pkg/websearch/manager.go @@ -0,0 +1,528 @@ +package websearch + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "log/slog" + "math/rand" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" + "github.com/redis/go-redis/v9" +) + +// ProviderConfig holds the configuration for a single search provider. +type ProviderConfig struct { + Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily + APIKey string `json:"api_key"` // secret + QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited + SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date + ProxyURL string `json:"-"` // resolved proxy URL (not persisted) + ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking + ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds) +} + +// Manager selects providers by quota-weighted load balancing and tracks quota via Redis. +type Manager struct { + configs []ProviderConfig + redis *redis.Client + + clientMu sync.Mutex + clientCache map[string]*http.Client +} + +// Timeout constants for proxy and search operations. +const ( + proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout + proxyTLSTimeout = 3 * time.Second // TLS handshake timeout + searchDataTimeout = 60 * time.Second // response data transfer timeout + searchRequestTimeout = searchDataTimeout + proxyDialTimeout + + quotaKeyPrefix = "websearch:quota:" + proxyUnavailableKey = "websearch:proxy_unavailable:%d" + proxyUnavailableTTL = 5 * time.Minute + quotaTTLBuffer = 24 * time.Hour + defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date + maxCachedClients = 100 +) + +// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue. +// Callers may use this to trigger account switching instead of direct fallback. +var ErrProxyUnavailable = errors.New("websearch: proxy unavailable") + +// quotaIncrScript atomically increments the counter and sets TTL on first creation. +var quotaIncrScript = redis.NewScript(` +local val = redis.call('INCR', KEYS[1]) +if val == 1 then + redis.call('EXPIRE', KEYS[1], ARGV[1]) +else + local ttl = redis.call('TTL', KEYS[1]) + if ttl == -1 then + redis.call('EXPIRE', KEYS[1], ARGV[1]) + end +end +return val +`) + +// NewManager creates a Manager with the given provider configs and Redis client. +// Provider order is preserved as-is; selectByQuotaWeight handles load balancing. +func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager { + copied := make([]ProviderConfig, len(configs)) + copy(copied, configs) + return &Manager{ + configs: copied, + redis: redisClient, + clientCache: make(map[string]*http.Client), + } +} + +// SearchWithBestProvider selects a provider using quota-weighted load balancing, +// reserves quota, executes the search, and rolls back quota on failure. +// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes. +func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) { + if strings.TrimSpace(req.Query) == "" { + return nil, "", fmt.Errorf("websearch: empty search query") + } + + candidates := m.filterAvailableProviders(ctx, req.ProxyURL) + if len(candidates) == 0 { + return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)") + } + + selected := m.selectByQuotaWeight(ctx, candidates) + + for _, cfg := range selected { + allowed, incremented := m.tryReserveQuota(ctx, cfg) + if !allowed { + continue + } + resp, err := m.executeSearch(ctx, cfg, req) + if err != nil { + if incremented { + m.rollbackQuota(ctx, cfg) + } + if isProxyError(err) { + m.markProxyUnavailable(ctx, cfg, req.ProxyURL) + if req.ProxyURL != "" { + // Account-level proxy is shared by all providers — no point + // trying others with the same broken proxy; signal account switch. + slog.Warn("websearch: account proxy error, aborting failover", + "provider", cfg.Type, "error", err) + return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error()) + } + // Provider-specific proxy failed — try the next provider which + // may use a different (or no) proxy. + slog.Warn("websearch: provider proxy error, trying next provider", + "provider", cfg.Type, "error", err) + continue + } + slog.Warn("websearch: provider search failed", + "provider", cfg.Type, "error", err) + continue + } + return resp, cfg.Type, nil + } + return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)") +} + +// filterAvailableProviders returns providers that have API keys, are not expired, +// and whose proxies are not marked unavailable. +func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig { + var out []ProviderConfig + for _, cfg := range m.configs { + if !m.isProviderAvailable(cfg) { + continue + } + proxyID := resolveProxyID(cfg, accountProxyURL) + if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) { + slog.Debug("websearch: proxy marked unavailable, skipping", + "provider", cfg.Type, "proxy_id", proxyID) + continue + } + out = append(out, cfg) + } + return out +} + +// weighted is a provider candidate with computed quota weight. +type weighted struct { + cfg ProviderConfig + weight int64 +} + +// selectByQuotaWeight orders candidates by remaining quota weight. +// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last. +// Among providers with quota, higher remaining quota = higher priority. +func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig { + items := m.computeWeights(ctx, candidates) + withQuota, withoutQuota := partitionByQuota(items) + sortByStableRandomWeight(withQuota) + return mergeWeightedResults(withQuota, withoutQuota, len(candidates)) +} + +func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted { + items := make([]weighted, 0, len(candidates)) + for _, cfg := range candidates { + w := int64(0) + if cfg.QuotaLimit > 0 { + used, _ := m.GetUsage(ctx, cfg.Type) + if remaining := cfg.QuotaLimit - used; remaining > 0 { + w = remaining + } + } + items = append(items, weighted{cfg: cfg, weight: w}) + } + return items +} + +func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) { + for _, item := range items { + if item.weight > 0 { + withQuota = append(withQuota, item) + } else { + withoutQuota = append(withoutQuota, item) + } + } + return +} + +// sortByStableRandomWeight assigns a fixed random factor to each item before sorting, +// ensuring deterministic sort behavior (transitivity) within a single call. +func sortByStableRandomWeight(items []weighted) { + if len(items) <= 1 { + return + } + type entry struct { + item weighted + factor float64 + } + entries := make([]entry, len(items)) + for i, item := range items { + entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())} + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].factor > entries[j].factor + }) + for i, e := range entries { + items[i] = e.item + } +} + +func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig { + result := make([]ProviderConfig, 0, capacity) + for _, item := range withQuota { + result = append(result, item.cfg) + } + for _, item := range withoutQuota { + result = append(result, item.cfg) + } + return result +} + +func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool { + if cfg.APIKey == "" { + return false + } + if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt { + slog.Info("websearch: provider expired, skipping", + "provider", cfg.Type, "expires_at", *cfg.ExpiresAt) + return false + } + return true +} + +// --- Proxy availability tracking --- + +// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL. +func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) { + proxyID := resolveProxyID(cfg, accountProxyURL) + if proxyID <= 0 || m.redis == nil { + return + } + key := fmt.Sprintf(proxyUnavailableKey, proxyID) + if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil { + slog.Warn("websearch: failed to mark proxy unavailable", + "proxy_id", proxyID, "error", err) + } +} + +// isProxyAvailable checks whether a proxy is currently marked as unavailable. +func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool { + if m.redis == nil || proxyID <= 0 { + return true + } + key := fmt.Sprintf(proxyUnavailableKey, proxyID) + val, err := m.redis.Get(ctx, key).Result() + if err != nil { + return true // Redis error → assume available + } + return val == "" +} + +// resolveProxyID determines the effective proxy ID for a provider+account combination. +func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 { + if accountProxyURL != "" { + return 0 // account proxy has no ID in provider config + } + return cfg.ProxyID +} + +// isProxyError checks whether the error is likely caused by proxy or network connectivity +// (as opposed to an API-level error from the search provider). +func isProxyError(err error) bool { + if err == nil { + return false + } + // Network-level errors (timeout, connection refused, DNS failure) + var netErr net.Error + if errors.As(err, &netErr) { + return true + } + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + // TLS handshake failures (often caused by proxy intercepting/blocking) + var tlsErr *tls.RecordHeaderError + if errors.As(err, &tlsErr) { + return true + } + // String-based detection for wrapped errors + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "proxy") || + strings.Contains(msg, "socks") || + strings.Contains(msg, "connection refused") || + strings.Contains(msg, "no such host") || + strings.Contains(msg, "i/o timeout") || + strings.Contains(msg, "tls handshake") || + strings.Contains(msg, "certificate") +} + +// --- Quota management --- + +func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) { + if cfg.QuotaLimit <= 0 { + return true, false + } + if m.redis == nil { + slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type) + return true, false + } + key := quotaRedisKey(cfg.Type) + ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds()) + newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64() + if err != nil { + slog.Warn("websearch: quota Lua INCR failed, allowing request", + "provider", cfg.Type, "error", err) + return true, false + } + if newVal > cfg.QuotaLimit { + if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil { + slog.Warn("websearch: quota over-limit DECR failed", + "provider", cfg.Type, "error", decrErr) + } + slog.Info("websearch: provider quota exhausted", + "provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit) + return false, false + } + return true, true +} + +func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) { + if cfg.QuotaLimit <= 0 || m.redis == nil { + return + } + key := quotaRedisKey(cfg.Type) + if err := m.redis.Decr(ctx, key).Err(); err != nil { + slog.Warn("websearch: quota rollback DECR failed", + "provider", cfg.Type, "error", err) + } +} + +// --- Search execution --- + +// TestSearch executes a search using the first available provider without reserving quota. +// Intended for admin test functionality only. +func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) { + if strings.TrimSpace(req.Query) == "" { + return nil, "", fmt.Errorf("websearch: empty search query") + } + for _, cfg := range m.configs { + if !m.isProviderAvailable(cfg) { + continue + } + resp, err := m.executeSearch(ctx, cfg, req) + if err != nil { + continue + } + return resp, cfg.Type, nil + } + return nil, "", fmt.Errorf("websearch: no available provider") +} + +func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) { + proxyURL := cfg.ProxyURL + if req.ProxyURL != "" { + proxyURL = req.ProxyURL + } + client, err := m.getOrCreateHTTPClient(proxyURL) + if err != nil { + return nil, fmt.Errorf("websearch: %w", err) + } + provider := m.buildProvider(cfg, client) + return provider.Search(ctx, req) +} + +// --- HTTP client cache --- + +func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) { + m.clientMu.Lock() + defer m.clientMu.Unlock() + + if c, ok := m.clientCache[proxyURL]; ok { + return c, nil + } + if len(m.clientCache) >= maxCachedClients { + m.clientCache = make(map[string]*http.Client) + } + c, err := newHTTPClient(proxyURL) + if err != nil { + return nil, err + } + m.clientCache[proxyURL] = c + return c, nil +} + +// newHTTPClient creates an HTTP client with proper timeout settings. +// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support +// (HTTP/HTTPS/SOCKS5/SOCKS5H). +// Returns error if proxyURL is invalid — never falls back to direct connection. +func newHTTPClient(proxyURL string) (*http.Client, error) { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext, + TLSHandshakeTimeout: proxyTLSTimeout, + ResponseHeaderTimeout: searchDataTimeout, + } + if proxyURL != "" { + parsed, err := url.Parse(proxyURL) + if err != nil { + return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err) + } + if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { + return nil, fmt.Errorf("configure proxy: %w", err) + } + } + return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil +} + +// GetUsage returns the current usage count for the given provider. +func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) { + if m.redis == nil { + return 0, nil + } + key := quotaRedisKey(providerType) + val, err := m.redis.Get(ctx, key).Int64() + if err == redis.Nil { + return 0, nil + } + return val, err +} + +// GetAllUsage returns usage for every configured provider. +func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 { + result := make(map[string]int64, len(m.configs)) + for _, cfg := range m.configs { + used, _ := m.GetUsage(ctx, cfg.Type) + result[cfg.Type] = used + } + return result +} + +// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0. +func (m *Manager) ResetUsage(ctx context.Context, providerType string) error { + if m.redis == nil { + return nil + } + key := quotaRedisKey(providerType) + return m.redis.Del(ctx, key).Err() +} + +// --- Provider factory --- + +func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider { + switch cfg.Type { + case braveProviderName: + return NewBraveProvider(cfg.APIKey, client) + case tavilyProviderName: + return NewTavilyProvider(cfg.APIKey, client) + default: + slog.Warn("websearch: unknown provider type, falling back to brave", + "type", cfg.Type) + return NewBraveProvider(cfg.APIKey, client) + } +} + +// --- Redis key helpers --- + +func quotaRedisKey(providerType string) string { + return quotaKeyPrefix + providerType +} + +// quotaTTLFromSubscription calculates the TTL for the quota counter based on +// the provider's subscription start date. Quota resets monthly from that date. +// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh). +func quotaTTLFromSubscription(subscribedAt *int64) time.Duration { + if subscribedAt == nil || *subscribedAt == 0 { + return defaultQuotaTTL + } + next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC()) + ttl := time.Until(next) + quotaTTLBuffer + if ttl <= quotaTTLBuffer { + // Already past the reset — next cycle + ttl = defaultQuotaTTL + } + return ttl +} + +// nextMonthlyReset returns the next monthly reset time based on the subscription start date. +// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc. +// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3). +func nextMonthlyReset(subscribedAt time.Time) time.Time { + now := time.Now().UTC() + if subscribedAt.IsZero() { + return now.AddDate(0, 1, 0) + } + months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month()) + if months < 0 { + months = 0 + } + candidate := addMonthsClamped(subscribedAt, months) + if candidate.After(now) { + return candidate + } + return addMonthsClamped(subscribedAt, months+1) +} + +// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month. +// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3). +func addMonthsClamped(t time.Time, months int) time.Time { + y, m, d := t.Date() + targetMonth := time.Month(int(m) + months) + targetYear := y + int(targetMonth-1)/12 + targetMonth = (targetMonth-1)%12 + 1 + // Last day of the target month + lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day() + if d > lastDay { + d = lastDay + } + return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC) +} diff --git a/backend/internal/pkg/websearch/manager_test.go b/backend/internal/pkg/websearch/manager_test.go new file mode 100644 index 00000000..a4413417 --- /dev/null +++ b/backend/internal/pkg/websearch/manager_test.go @@ -0,0 +1,323 @@ +package websearch + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNewManager_PreservesOrder(t *testing.T) { + configs := []ProviderConfig{ + {Type: "brave", APIKey: "k3"}, + {Type: "tavily", APIKey: "k1"}, + } + m := NewManager(configs, nil) + require.Equal(t, "brave", m.configs[0].Type) + require.Equal(t, "tavily", m.configs[1].Type) +} + +func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) { + m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil) + _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""}) + require.ErrorContains(t, err, "empty search query") + + _, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "}) + require.ErrorContains(t, err, "empty search query") +} + +func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) { + m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil) + _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"}) + require.ErrorContains(t, err, "no available provider") +} + +func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) { + past := time.Now().Add(-1 * time.Hour).Unix() + m := NewManager([]ProviderConfig{ + {Type: "brave", APIKey: "k", ExpiresAt: &past}, + }, nil) + _, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"}) + require.ErrorContains(t, err, "no available provider") +} + +func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) { + srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + resp := braveResponse{} + resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}} + _ = json.NewEncoder(w).Encode(resp) + })) + defer srvBrave.Close() + + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srvBrave.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + m := NewManager([]ProviderConfig{ + {Type: "brave", APIKey: "k1"}, + {Type: "tavily", APIKey: "k2"}, + }, nil) + m.clientCache[srvBrave.URL] = srvBrave.Client() + m.clientCache[""] = srvBrave.Client() + + resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"}) + require.NoError(t, err) + require.Equal(t, "brave", providerName) + require.Len(t, resp.Results, 1) + require.Equal(t, "from brave", resp.Results[0].Snippet) +} + +func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + resp := braveResponse{} + resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}} + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + origURL := *braveSearchURL + u, _ := http.NewRequest("GET", srv.URL, nil) + *braveSearchURL = *u.URL + defer func() { *braveSearchURL = origURL }() + + m := NewManager([]ProviderConfig{ + {Type: "brave", APIKey: "k", QuotaLimit: 100}, + }, nil) + m.clientCache[""] = srv.Client() + + resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"}) + require.NoError(t, err) + require.Len(t, resp.Results, 1) +} + +func TestManager_GetUsage_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + used, err := m.GetUsage(context.Background(), "brave") + require.NoError(t, err) + require.Equal(t, int64(0), used) +} + +func TestManager_GetAllUsage_NilRedis(t *testing.T) { + m := NewManager([]ProviderConfig{ + {Type: "brave"}, + }, nil) + usage := m.GetAllUsage(context.Background()) + require.Equal(t, int64(0), usage["brave"]) +} + +// --- Quota TTL from subscription --- + +func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) { + ttl := quotaTTLFromSubscription(nil) + require.Equal(t, defaultQuotaTTL, ttl) +} + +func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) { + zero := int64(0) + ttl := quotaTTLFromSubscription(&zero) + require.Equal(t, defaultQuotaTTL, ttl) +} + +func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) { + // Subscribed 10 days ago — next reset in ~20 days + sub := time.Now().Add(-10 * 24 * time.Hour).Unix() + ttl := quotaTTLFromSubscription(&sub) + require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days + require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer) +} + +func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) { + // Subscribed on the 10th of this month (always valid day) + now := time.Now().UTC() + sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC) + next := nextMonthlyReset(sub) + require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now") + require.True(t, next.Before(now.AddDate(0, 1, 1))) +} + +func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) { + // Subscribed 6 months ago on the 1st + sub := time.Now().UTC().AddDate(0, -6, 0) + sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC) + next := nextMonthlyReset(sub) + require.True(t, next.After(time.Now().UTC())) + // Should be within the next 31 days + require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1))) +} + +func TestNextMonthlyReset_FutureSubscription(t *testing.T) { + sub := time.Now().UTC().AddDate(0, 0, 5) + next := nextMonthlyReset(sub) + require.True(t, next.After(time.Now().UTC())) +} + +func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) { + sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year) +} + +func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) { + sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year) +} + +func TestAddMonthsClamped_Mar31ToApr(t *testing.T) { + sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(4), next.Month()) + require.Equal(t, 30, next.Day()) // Apr has 30 days +} + +func TestAddMonthsClamped_NormalDay(t *testing.T) { + sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC) + next := addMonthsClamped(sub, 1) + require.Equal(t, time.Month(2), next.Month()) + require.Equal(t, 15, next.Day()) // no clamping needed +} + +// --- Redis key --- + +func TestQuotaRedisKey_Format(t *testing.T) { + key := quotaRedisKey("brave") + require.Equal(t, "websearch:quota:brave", key) +} + +// --- isProviderAvailable --- + +func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) { + m := NewManager(nil, nil) + require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""})) +} + +func TestIsProviderAvailable_Expired(t *testing.T) { + m := NewManager(nil, nil) + past := time.Now().Add(-1 * time.Hour).Unix() + require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past})) +} + +func TestIsProviderAvailable_Valid(t *testing.T) { + m := NewManager(nil, nil) + future := time.Now().Add(1 * time.Hour).Unix() + require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future})) + require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry +} + +// --- resolveProxyID --- + +func TestResolveProxyID_AccountProxyOverrides(t *testing.T) { + cfg := ProviderConfig{ProxyID: 42} + require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080")) + require.Equal(t, int64(42), resolveProxyID(cfg, "")) +} + +// --- isProxyError --- + +func TestIsProxyError_Nil(t *testing.T) { + require.False(t, isProxyError(nil)) +} + +func TestIsProxyError_ConnectionRefused(t *testing.T) { + require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused"))) +} + +func TestIsProxyError_Timeout(t *testing.T) { + require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy"))) +} + +func TestIsProxyError_SOCKS(t *testing.T) { + require.True(t, isProxyError(fmt.Errorf("socks connect failed"))) +} + +func TestIsProxyError_TLSHandshake(t *testing.T) { + require.True(t, isProxyError(fmt.Errorf("tls handshake timeout"))) +} + +func TestIsProxyError_APIError_NotProxy(t *testing.T) { + require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded"))) +} + +// --- isProxyAvailable (nil Redis) --- + +func TestIsProxyAvailable_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + require.True(t, m.isProxyAvailable(context.Background(), 42)) +} + +func TestIsProxyAvailable_ZeroID(t *testing.T) { + m := NewManager(nil, nil) + require.True(t, m.isProxyAvailable(context.Background(), 0)) +} + +// --- selectByQuotaWeight --- + +func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) { + m := NewManager(nil, nil) + candidates := []ProviderConfig{ + {Type: "brave", APIKey: "k1", QuotaLimit: 0}, + {Type: "tavily", APIKey: "k2", QuotaLimit: 100}, + } + result := m.selectByQuotaWeight(context.Background(), candidates) + require.Len(t, result, 2) + require.Equal(t, "tavily", result[0].Type) + require.Equal(t, "brave", result[1].Type) +} + +func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) { + m := NewManager(nil, nil) + candidates := []ProviderConfig{ + {Type: "brave", APIKey: "k1", QuotaLimit: 0}, + {Type: "tavily", APIKey: "k2", QuotaLimit: 0}, + } + result := m.selectByQuotaWeight(context.Background(), candidates) + require.Len(t, result, 2) +} + +func TestSelectByQuotaWeight_Empty(t *testing.T) { + m := NewManager(nil, nil) + result := m.selectByQuotaWeight(context.Background(), nil) + require.Empty(t, result) +} + +// --- newHTTPClient --- + +func TestNewHTTPClient_NoProxy(t *testing.T) { + c, err := newHTTPClient("") + require.NoError(t, err) + require.NotNil(t, c) +} + +func TestNewHTTPClient_InvalidProxy(t *testing.T) { + _, err := newHTTPClient("://bad-url") + require.Error(t, err) + require.Contains(t, err.Error(), "invalid proxy URL") +} + +func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) { + c, err := newHTTPClient("http://proxy.example.com:8080") + require.NoError(t, err) + require.NotNil(t, c) +} + +func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) { + c, err := newHTTPClient("socks5://proxy.example.com:1080") + require.NoError(t, err) + require.NotNil(t, c) +} + +// --- ResetUsage --- + +func TestManager_ResetUsage_NilRedis(t *testing.T) { + m := NewManager(nil, nil) + err := m.ResetUsage(context.Background(), "brave") + require.NoError(t, err) +} diff --git a/backend/internal/pkg/websearch/provider.go b/backend/internal/pkg/websearch/provider.go new file mode 100644 index 00000000..3424c056 --- /dev/null +++ b/backend/internal/pkg/websearch/provider.go @@ -0,0 +1,11 @@ +package websearch + +import "context" + +// Provider is the interface every search backend must implement. +type Provider interface { + // Name returns the provider identifier ("brave" or "tavily"). + Name() string + // Search executes a web search and returns results. + Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) +} diff --git a/backend/internal/pkg/websearch/tavily.go b/backend/internal/pkg/websearch/tavily.go new file mode 100644 index 00000000..ac4928a6 --- /dev/null +++ b/backend/internal/pkg/websearch/tavily.go @@ -0,0 +1,107 @@ +package websearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +const ( + tavilySearchEndpoint = "https://api.tavily.com/search" + tavilyProviderName = "tavily" + tavilySearchDepthBasic = "basic" +) + +// TavilyProvider implements web search via the Tavily Search API. +type TavilyProvider struct { + apiKey string + httpClient *http.Client +} + +// NewTavilyProvider creates a Tavily Search provider. +// The caller is responsible for configuring the http.Client with proxy/timeouts. +func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &TavilyProvider{apiKey: apiKey, httpClient: httpClient} +} + +func (t *TavilyProvider) Name() string { return tavilyProviderName } + +func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) { + maxResults := req.MaxResults + if maxResults <= 0 { + maxResults = defaultMaxResults + } + + payload := tavilyRequest{ + APIKey: t.apiKey, + Query: req.Query, + MaxResults: maxResults, + SearchDepth: tavilySearchDepthBasic, + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("tavily: encode request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("tavily: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := t.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("tavily: request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, fmt.Errorf("tavily: read body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body)) + } + + var raw tavilyResponse + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("tavily: decode response: %w", err) + } + + results := make([]SearchResult, 0, len(raw.Results)) + for _, r := range raw.Results { + results = append(results, SearchResult{ + URL: r.URL, + Title: r.Title, + Snippet: r.Content, + }) + } + + return &SearchResponse{Results: results, Query: req.Query}, nil +} + +type tavilyRequest struct { + APIKey string `json:"api_key"` + Query string `json:"query"` + MaxResults int `json:"max_results"` + SearchDepth string `json:"search_depth"` +} + +type tavilyResponse struct { + Results []tavilyResult `json:"results"` +} + +type tavilyResult struct { + URL string `json:"url"` + Title string `json:"title"` + Content string `json:"content"` + Score float64 `json:"score"` +} diff --git a/backend/internal/pkg/websearch/tavily_test.go b/backend/internal/pkg/websearch/tavily_test.go new file mode 100644 index 00000000..e1b6819a --- /dev/null +++ b/backend/internal/pkg/websearch/tavily_test.go @@ -0,0 +1,63 @@ +package websearch + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTavilyProvider_Name(t *testing.T) { + p := NewTavilyProvider("key", nil) + require.Equal(t, "tavily", p.Name()) +} + +func TestTavilyProvider_Search_RequestConstruction(t *testing.T) { + // Verify tavilyRequest struct fields map correctly + req := tavilyRequest{ + APIKey: "test-key", + Query: "golang", + MaxResults: 3, + SearchDepth: tavilySearchDepthBasic, + } + data, err := json.Marshal(req) + require.NoError(t, err) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(data, &parsed)) + require.Equal(t, "test-key", parsed["api_key"]) + require.Equal(t, "golang", parsed["query"]) + require.Equal(t, float64(3), parsed["max_results"]) + require.Equal(t, "basic", parsed["search_depth"]) +} + +func TestTavilyProvider_Search_ResponseParsing(t *testing.T) { + rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}` + var resp tavilyResponse + require.NoError(t, json.Unmarshal([]byte(rawResp), &resp)) + require.Len(t, resp.Results, 1) + require.Equal(t, "https://go.dev", resp.Results[0].URL) + require.Equal(t, "Go programming language", resp.Results[0].Content) + require.InDelta(t, 0.95, resp.Results[0].Score, 0.001) + + // Verify mapping to SearchResult + results := make([]SearchResult, 0, len(resp.Results)) + for _, r := range resp.Results { + results = append(results, SearchResult{ + URL: r.URL, Title: r.Title, Snippet: r.Content, + }) + } + require.Equal(t, "Go programming language", results[0].Snippet) + require.Equal(t, "", results[0].PageAge) +} + +func TestTavilyProvider_Search_EmptyResults(t *testing.T) { + var resp tavilyResponse + require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp)) + require.Empty(t, resp.Results) +} + +func TestTavilyProvider_Search_InvalidJSON(t *testing.T) { + var resp tavilyResponse + require.Error(t, json.Unmarshal([]byte("not json"), &resp)) +} diff --git a/backend/internal/pkg/websearch/types.go b/backend/internal/pkg/websearch/types.go new file mode 100644 index 00000000..bb489690 --- /dev/null +++ b/backend/internal/pkg/websearch/types.go @@ -0,0 +1,30 @@ +package websearch + +// SearchResult represents a single web search result. +type SearchResult struct { + URL string `json:"url"` + Title string `json:"title"` + Snippet string `json:"snippet"` + PageAge string `json:"page_age,omitempty"` +} + +// SearchRequest describes a web search to perform. +type SearchRequest struct { + Query string + MaxResults int // defaults to defaultMaxResults if <= 0 + ProxyURL string // optional HTTP proxy URL +} + +// SearchResponse holds the results of a web search. +type SearchResponse struct { + Results []SearchResult + Query string // the query that was actually executed +} + +const defaultMaxResults = 5 + +// Provider type identifiers. +const ( + ProviderTypeBrave = "brave" + ProviderTypeTavily = "tavily" +) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 7fd98855..38ea9bde 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se WithUser(func(q *dbent.UserQuery) { q.Select( user.FieldID, + user.FieldEmail, + user.FieldUsername, user.FieldStatus, user.FieldRole, user.FieldBalance, user.FieldConcurrency, + user.FieldBalanceNotifyEnabled, + user.FieldBalanceNotifyThresholdType, + user.FieldBalanceNotifyThreshold, + user.FieldBalanceNotifyExtraEmails, + user.FieldTotalRecharged, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -639,22 +646,31 @@ func userEntityToService(u *dbent.User) *service.User { if u == nil { return nil } - return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + out := &service.User{ + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + BalanceNotifyEnabled: u.BalanceNotifyEnabled, + BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, + BalanceNotifyThreshold: u.BalanceNotifyThreshold, + TotalRecharged: u.TotalRecharged, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } + // Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format) + if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" { + out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails) + } + return out } func groupEntityToService(g *dbent.Group) *service.Group { diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 49c2d8d9..2cb90aab 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel } } + // 设置账号统计定价规则 + if len(channel.AccountStatsPricingRules) > 0 { + if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil { + return err + } + } + return nil }) } func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} - var modelMappingJSON []byte + var modelMappingJSON, featuresConfigJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha return nil, fmt.Errorf("get channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha } ch.ModelPricing = pricing + statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id) + if err != nil { + return nil, err + } + ch.AccountStatsPricingRules = statsPricingRules + return ch, nil } @@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW() - WHERE id = $7`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW() + WHERE id = $10`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel } } + // 更新账号统计定价规则 + if channel.AccountStatsPricingRules != nil { + if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil { + return err + } + } + return nil }) } @@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) @@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati if err != nil { return nil, nil, err } + statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs) + if err != nil { + return nil, nil, err + } for i := range channels { channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID] + channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID] } } @@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err return nil, err } + // 批量加载账号统计定价规则 + statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs) + if err != nil { + return nil, err + } + for i := range channels { channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID] + channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID] } return channels, nil @@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string { return m } +func marshalFeaturesConfig(m map[string]any) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal features_config: %w", err) + } + return data, nil +} + +func unmarshalFeaturesConfig(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + // GetGroupPlatforms 批量查询分组 ID 对应的平台 func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { if len(groupIDs) == 0 { diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go new file mode 100644 index 00000000..9e00fed8 --- /dev/null +++ b/backend/internal/repository/channel_repo_account_stats_pricing.go @@ -0,0 +1,244 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// --- 账号统计定价规则 --- + +// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价) +func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) { + // 1. 查询规则 + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing rules: %w", err) + } + defer func() { _ = rows.Close() }() + + var allRules []service.AccountStatsPricingRule + var ruleIDs []int64 + for rows.Next() { + var rule service.AccountStatsPricingRule + if err := rows.Scan( + &rule.ID, &rule.ChannelID, &rule.Name, + pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs), + &rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing rule: %w", err) + } + ruleIDs = append(ruleIDs, rule.ID) + allRules = append(allRules, rule) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate account stats pricing rules: %w", err) + } + + // 2. 批量加载规则的模型定价 + pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs) + if err != nil { + return nil, err + } + + // 3. 按 channelID 分组并关联定价 + result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs)) + for i := range allRules { + allRules[i].Pricing = pricingMap[allRules[i].ID] + result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i]) + } + + return result, nil +} + +// batchLoadAccountStatsModelPricing 批量加载规则的模型定价 +func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) { + if len(ruleIDs) == 0 { + return make(map[int64][]service.ChannelModelPricing), nil + } + + rows, err := r.db.QueryContext(ctx, + `SELECT id, rule_id, platform, models, billing_mode, input_price, output_price, + cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`, + pq.Array(ruleIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs)) + for rows.Next() { + var p service.ChannelModelPricing + var ruleID int64 + var modelsJSON []byte + if err := rows.Scan( + &p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode, + &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, + &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats model pricing: %w", err) + } + if err := json.Unmarshal(modelsJSON, &p.Models); err != nil { + p.Models = []string{} + } + pricingMap[ruleID] = append(pricingMap[ruleID], p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate account stats model pricing: %w", err) + } + + // Load intervals for all pricing entries. + var allPricingIDs []int64 + for _, pricings := range pricingMap { + for _, p := range pricings { + allPricingIDs = append(allPricingIDs, p.ID) + } + } + if len(allPricingIDs) > 0 { + intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs) + if err != nil { + return nil, err + } + for ruleID, pricings := range pricingMap { + for i := range pricings { + pricings[i].Intervals = intervalsMap[pricings[i].ID] + } + pricingMap[ruleID] = pricings + } + } + + return pricingMap, nil +} + +// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用) +func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) { + result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID}) + if err != nil { + return nil, err + } + return result[channelID], nil +} + +// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的) +func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error { + // CASCADE 会自动删除关联的 model_pricing + if _, err := tx.ExecContext(ctx, + `DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID, + ); err != nil { + return fmt.Errorf("delete old account stats pricing rules: %w", err) + } + + for i := range rules { + rules[i].ChannelID = channelID + if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil { + return fmt.Errorf("insert account stats pricing rule: %w", err) + } + } + return nil +} + +// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价 +func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error { + err := tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order) + VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`, + rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder, + ).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt) + if err != nil { + return fmt.Errorf("insert account stats pricing rule: %w", err) + } + + for j := range rule.Pricing { + if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil { + return err + } + } + return nil +} + +// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价 +func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := pricing.Platform + err = tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + ruleID, platform, modelsJSON, billingMode, + pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, + ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) + if err != nil { + return fmt.Errorf("insert account stats model pricing: %w", err) + } + // Persist intervals (mirrors channel_pricing_intervals logic). + for i := range pricing.Intervals { + iv := &pricing.Intervals[i] + iv.PricingID = pricing.ID + if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil { + return err + } + } + return nil +} + +// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry. +func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error { + return tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_intervals + (pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel, + iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice, + iv.PerRequestPrice, iv.SortOrder, + ).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt) +} + +// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries. +func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) { + if len(pricingIDs) == 0 { + return nil, nil + } + rows, err := r.db.QueryContext(ctx, + `SELECT id, pricing_id, min_tokens, max_tokens, tier_label, + input_price, output_price, cache_write_price, cache_read_price, + per_request_price, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_intervals + WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`, + pq.Array(pricingIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err) + } + defer func() { _ = rows.Close() }() + + result := make(map[int64][]service.PricingInterval) + for rows.Next() { + var iv service.PricingInterval + if err := rows.Scan( + &iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel, + &iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice, + &iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing interval: %w", err) + } + result[iv.PricingID] = append(result[iv.PricingID], iv) + } + return result, rows.Err() +} diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index e82a73a3..5e09e75d 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -331,6 +331,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, COALESCE(SUM(total_cost), 0) AS total_cost, COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) AS account_cost, COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms FROM usage_logs WHERE created_at >= $1 AND created_at < $2 @@ -351,6 +352,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont cache_read_tokens, total_cost, actual_cost, + account_cost, total_duration_ms, active_users, computed_at @@ -364,6 +366,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont hourly.cache_read_tokens, hourly.total_cost, hourly.actual_cost, + hourly.account_cost, hourly.total_duration_ms, COALESCE(user_counts.active_users, 0) AS active_users, NOW() @@ -378,6 +381,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont cache_read_tokens = EXCLUDED.cache_read_tokens, total_cost = EXCLUDED.total_cost, actual_cost = EXCLUDED.actual_cost, + account_cost = EXCLUDED.account_cost, total_duration_ms = EXCLUDED.total_duration_ms, active_users = EXCLUDED.active_users, computed_at = EXCLUDED.computed_at @@ -399,6 +403,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, COALESCE(SUM(total_cost), 0) AS total_cost, COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(account_cost), 0) AS account_cost, COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2 @@ -419,6 +424,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte cache_read_tokens, total_cost, actual_cost, + account_cost, total_duration_ms, active_users, computed_at @@ -432,6 +438,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte daily.cache_read_tokens, daily.total_cost, daily.actual_cost, + daily.account_cost, daily.total_duration_ms, COALESCE(user_counts.active_users, 0) AS active_users, NOW() @@ -446,6 +453,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte cache_read_tokens = EXCLUDED.cache_read_tokens, total_cost = EXCLUDED.total_cost, actual_cost = EXCLUDED.actual_cost, + account_cost = EXCLUDED.account_cost, total_duration_ms = EXCLUDED.total_duration_ms, active_users = EXCLUDED.active_users, computed_at = EXCLUDED.computed_at diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 8f2b8eca..96a23a8e 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -3,6 +3,8 @@ package repository import ( "context" "encoding/json" + "fmt" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -11,23 +13,33 @@ import ( const ( verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" passwordResetKeyPrefix = "password_reset:" passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. +// Email is lowercased for case-insensitive consistency. func verifyCodeKey(email string) string { - return verifyCodeKeyPrefix + email + return verifyCodeKeyPrefix + strings.ToLower(email) +} + +// notifyVerifyKey generates the Redis key for notify email verification code. +// Email is lowercased to prevent case-sensitive key mismatch (the business layer +// uses strings.EqualFold for comparison). +func notifyVerifyKey(email string) string { + return notifyVerifyKeyPrefix + strings.ToLower(email) } // passwordResetKey generates the Redis key for password reset token. func passwordResetKey(email string) string { - return passwordResetKeyPrefix + email + return passwordResetKeyPrefix + strings.ToLower(email) } // passwordResetSentAtKey generates the Redis key for password reset email sent timestamp. func passwordResetSentAtKey(email string) string { - return passwordResetSentAtKeyPrefix + email + return passwordResetSentAtKeyPrefix + strings.ToLower(email) } type emailCache struct { @@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st key := passwordResetSentAtKey(email) return c.rdb.Set(ctx, key, "1", ttl).Err() } + +// Notify email verification code methods + +func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) { + key := notifyVerifyKey(email) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + return nil, err + } + var data service.VerificationCodeData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, err + } + return &data, nil +} + +func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error { + key := notifyVerifyKey(email) + val, err := json.Marshal(data) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error { + key := notifyVerifyKey(email) + return c.rdb.Del(ctx, key).Err() +} + +// User-level rate limiting for notify email verification codes + +func notifyCodeUserRateKey(userID int64) string { + return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID) +} + +func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Incr(ctx, key).Result() + if err != nil { + return 0, err + } + // Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE. + if err := c.rdb.Expire(ctx, key, window).Err(); err != nil { + return count, fmt.Errorf("expire notify code rate key: %w", err) + } + return count, nil +} + +func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Get(ctx, key).Int64() + if err != nil { + return 0, err + } + return count, nil +} diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index e9be8c7a..add0e501 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -426,6 +426,13 @@ func filterSchedulerExtra(extra map[string]any) map[string]any { "window_cost_sticky_reserve", "max_sessions", "session_idle_timeout_minutes", + "openai_oauth_responses_websockets_v2_enabled", + "openai_oauth_responses_websockets_v2_mode", + "openai_apikey_responses_websockets_v2_enabled", + "openai_apikey_responses_websockets_v2_mode", + "responses_websockets_v2_enabled", + "openai_ws_enabled", + "openai_ws_force_http", } filtered := make(map[string]any) for _, key := range keys { diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go new file mode 100644 index 00000000..bcfd0e7a --- /dev/null +++ b/backend/internal/repository/scheduler_cache_unit_test.go @@ -0,0 +1,33 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) { + account := service.Account{ + ID: 42, + Platform: service.PlatformOpenAI, + Type: service.AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, + "openai_ws_force_http": true, + "mixed_scheduling": true, + "unused_large_field": "drop-me", + }, + } + + got := buildSchedulerMetadataAccount(account) + + require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"]) + require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"]) + require.Equal(t, true, got.Extra["openai_ws_force_http"]) + require.Equal(t, true, got.Extra["mixed_scheduling"]) + require.Nil(t, got.Extra["unused_large_field"]) +} diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go index b4c76da5..2b6edad3 100644 --- a/backend/internal/repository/usage_billing_repo.go +++ b/backend/internal/repository/usage_billing_repo.go @@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t } if cmd.BalanceCost > 0 { - if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil { + newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost) + if err != nil { return err } + result.NewBalance = &newBalance } if cmd.APIKeyQuotaCost > 0 { @@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t } if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) { - if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil { + quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost) + if err != nil { return err } + result.QuotaState = quotaState } return nil @@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip return service.ErrSubscriptionNotFound } -func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error { - res, err := tx.ExecContext(ctx, ` +func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) { + var newBalance float64 + err := tx.QueryRowContext(ctx, ` UPDATE users SET balance = balance - $1, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL - `, amount, userID) + RETURNING balance + `, amount, userID).Scan(&newBalance) + if errors.Is(err, sql.ErrNoRows) { + return 0, service.ErrUserNotFound + } if err != nil { - return err + return 0, err } - affected, err := res.RowsAffected() - if err != nil { - return err - } - if affected > 0 { - return nil - } - return service.ErrUserNotFound + return newBalance, nil } func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) { @@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe return nil } -func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error { +func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) { rows, err := tx.QueryContext(ctx, `UPDATE accounts SET extra = ( COALESCE(extra, '{}'::jsonb) @@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI || CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_daily_used', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END, 'quota_daily_start', - CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz) - + '24 hours'::interval <= NOW() + CASE WHEN `+dailyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END ) + || CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END || CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN jsonb_build_object( 'quota_weekly_used', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN $1 ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END, 'quota_weekly_start', - CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz) - + '168 hours'::interval <= NOW() + CASE WHEN `+weeklyExpiredExpr+` THEN `+nowUTC+` ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END ) + || CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL + THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`) + ELSE '{}'::jsonb END ELSE '{}'::jsonb END ), updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL RETURNING COALESCE((extra->>'quota_used')::numeric, 0), - COALESCE((extra->>'quota_limit')::numeric, 0)`, + COALESCE((extra->>'quota_limit')::numeric, 0), + COALESCE((extra->>'quota_daily_used')::numeric, 0), + COALESCE((extra->>'quota_daily_limit')::numeric, 0), + COALESCE((extra->>'quota_weekly_used')::numeric, 0), + COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`, amount, accountID) if err != nil { - return err + return nil, err } defer func() { _ = rows.Close() }() - var newUsed, limit float64 + var state service.AccountQuotaState if rows.Next() { - if err := rows.Scan(&newUsed, &limit); err != nil { - return err + if err := rows.Scan( + &state.TotalUsed, &state.TotalLimit, + &state.DailyUsed, &state.DailyLimit, + &state.WeeklyUsed, &state.WeeklyLimit, + ); err != nil { + return nil, err } } else { if err := rows.Err(); err != nil { - return err + return nil, err } - return service.ErrAccountNotFound + return nil, service.ErrAccountNotFound } if err := rows.Err(); err != nil { - return err + return nil, err } - if limit > 0 && newUsed >= limit && (newUsed-amount) < limit { + if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit { if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) - return err + return nil, err } } - return nil + return &state, nil } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 3ba2191e..f2fb87da 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{ "text", // model_mapping_chain "text", // billing_tier "text", // billing_mode + "numeric", // account_stats_cost "timestamptz", // created_at } @@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) AS (VALUES `) @@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) SELECT @@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*45) + args := make([]any, 0, len(preparedList)*46) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) SELECT @@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { modelMappingChain, billingTier, billingMode, + log.AccountStatsCost, // account_stats_cost createdAt, }, } @@ -1518,6 +1528,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(SUM(account_cost), 0) as total_account_cost, COALESCE(SUM(total_duration_ms), 0) as total_duration_ms FROM usage_dashboard_daily ` @@ -1534,6 +1545,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte &stats.TotalCacheReadTokens, &stats.TotalCost, &stats.TotalActualCost, + &stats.TotalAccountCost, &totalDurationMs, ); err != nil { return err @@ -1552,6 +1564,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte cache_read_tokens as today_cache_read_tokens, total_cost as today_cost, actual_cost as today_actual_cost, + account_cost as today_account_cost, active_users as active_users FROM usage_dashboard_daily WHERE bucket_date = $1::date @@ -1568,6 +1581,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte &stats.TodayCacheReadTokens, &stats.TodayCost, &stats.TodayActualCost, + &stats.TodayAccountCost, &stats.ActiveUsers, ); err != nil { if err != sql.ErrNoRows { @@ -1603,6 +1617,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co cache_read_tokens, total_cost, actual_cost, + COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) AS account_cost, COALESCE(duration_ms, 0) AS duration_ms FROM usage_logs WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) @@ -1616,6 +1631,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens, COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost, COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost, + COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_account_cost, COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms, COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests, COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens, @@ -1623,7 +1639,8 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens, COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens, COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost, + COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_account_cost FROM scoped ` var totalDurationMs int64 @@ -1639,6 +1656,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TotalCacheReadTokens, &stats.TotalCost, &stats.TotalActualCost, + &stats.TotalAccountCost, &totalDurationMs, &stats.TodayRequests, &stats.TodayInputTokens, @@ -1647,6 +1665,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TodayCacheReadTokens, &stats.TodayCost, &stats.TodayActualCost, + &stats.TodayAccountCost, ); err != nil { return err } @@ -1959,7 +1978,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID SELECT COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -1989,7 +2008,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI SELECT COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -2026,7 +2045,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc account_id, COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -2585,7 +2604,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64 COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(actual_cost), 0) as actual_cost + COALESCE(SUM(actual_cost), 0) as actual_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 GROUP BY model @@ -2990,8 +3010,9 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } + accountCostExpr := "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost" modelExpr := resolveModelDimensionExpression(source) query := fmt.Sprintf(` @@ -3004,10 +3025,11 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, COALESCE(SUM(total_cost), 0) as cost, + %s, %s FROM usage_logs WHERE created_at >= $1 AND created_at < $2 - `, modelExpr, actualCostExpr) + `, modelExpr, actualCostExpr, accountCostExpr) args := []any{startTime, endTime} if userID > 0 { @@ -3062,7 +3084,8 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start COUNT(*) as requests, COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, COALESCE(SUM(ul.total_cost), 0) as cost, - COALESCE(SUM(ul.actual_cost), 0) as actual_cost + COALESCE(SUM(ul.actual_cost), 0) as actual_cost, + COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost FROM usage_logs ul LEFT JOIN groups g ON g.id = ul.group_id WHERE ul.created_at >= $1 AND ul.created_at < $2 @@ -3113,6 +3136,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start &row.TotalTokens, &row.Cost, &row.ActualCost, + &row.AccountCost, ); err != nil { return nil, err } @@ -3133,7 +3157,8 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim COUNT(*) as requests, COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, COALESCE(SUM(ul.total_cost), 0) as cost, - COALESCE(SUM(ul.actual_cost), 0) as actual_cost + COALESCE(SUM(ul.actual_cost), 0) as actual_cost, + COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost FROM usage_logs ul LEFT JOIN users u ON u.id = ul.user_id WHERE ul.created_at >= $1 AND ul.created_at < $2 @@ -3204,6 +3229,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim &row.TotalTokens, &row.Cost, &row.ActualCost, + &row.AccountCost, ); err != nil { return nil, err } @@ -3358,7 +3384,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost, COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs %s @@ -3382,9 +3408,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us ); err != nil { return nil, err } - if filters.AccountID > 0 { - stats.TotalAccountCost = &totalAccountCost - } + stats.TotalAccountCost = &totalAccountCost stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens start := time.Unix(0, 0).UTC() @@ -3433,7 +3457,7 @@ type EndpointStat = usagestats.EndpointStat func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } query := fmt.Sprintf(` @@ -3500,7 +3524,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } query := fmt.Sprintf(` @@ -3591,7 +3615,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 @@ -4069,6 +4093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e modelMappingChain sql.NullString billingTier sql.NullString billingMode sql.NullString + accountStatsCost sql.NullFloat64 createdAt time.Time ) @@ -4118,6 +4143,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &modelMappingChain, &billingTier, &billingMode, + &accountStatsCost, &createdAt, ); err != nil { return nil, err @@ -4214,6 +4240,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if billingMode.Valid { log.BillingMode = &billingMode.String } + if accountStatsCost.Valid { + log.AccountStatsCost = &accountStatsCost.Float64 + } return log, nil } @@ -4257,6 +4286,7 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) { &row.TotalTokens, &row.Cost, &row.ActualCost, + &row.AccountCost, ); err != nil { return nil, err } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 0383f3bc..ed3050d8 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -753,8 +753,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch") s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch") s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch") + // account_cost falls back to total_cost when account_stats_cost is NULL + s.Require().Equal(baseStats.TotalAccountCost+2.3, stats.TotalAccountCost, "TotalAccountCost mismatch") s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1") s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0") + s.Require().GreaterOrEqual(stats.TodayAccountCost, 0.0, "expected TodayAccountCost >= 0") wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0) s.Require().NoError(err, "getPerformanceStats") @@ -833,6 +836,8 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() { s.Require().Equal(int64(45), stats.TotalTokens) s.Require().Equal(1.5, stats.TotalCost) s.Require().Equal(1.4, stats.TotalActualCost) + // account_cost = COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) = total_cost + s.Require().Equal(1.5, stats.TotalAccountCost) s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001) } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index b9cb6a13..a5ff4bc1 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_mode + sqlmock.AnyArg(), // account_stats_cost createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) @@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_mode + sqlmock.AnyArg(), // account_stats_cost createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) @@ -299,7 +301,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). WithArgs(start, end, requestType). - WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"})) + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost", "account_cost"})) stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) require.NoError(t, err) @@ -344,6 +346,93 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) require.NoError(t, err) require.Equal(t, int64(1), stats.TotalRequests) require.Equal(t, int64(9), stats.TotalTokens) + require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost should always be returned") + require.Equal(t, 1.2, *stats.TotalAccountCost) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetModelStatsAccountCostColumn(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + + mock.ExpectQuery("FROM usage_logs"). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{ + "model", "requests", "input_tokens", "output_tokens", + "cache_creation_tokens", "cache_read_tokens", "total_tokens", + "cost", "actual_cost", "account_cost", + }). + AddRow("claude-opus-4-6", int64(10), int64(100), int64(200), int64(5), int64(3), int64(308), 2.5, 2.0, 1.8). + AddRow("claude-sonnet-4-6", int64(5), int64(50), int64(100), int64(0), int64(0), int64(150), 1.0, 0.8, 0.7)) + + results, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil) + require.NoError(t, err) + require.Len(t, results, 2) + require.Equal(t, "claude-opus-4-6", results[0].Model) + require.Equal(t, 2.5, results[0].Cost) + require.Equal(t, 2.0, results[0].ActualCost) + require.Equal(t, 1.8, results[0].AccountCost) + require.Equal(t, "claude-sonnet-4-6", results[1].Model) + require.Equal(t, 0.7, results[1].AccountCost) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetGroupStatsAccountCostColumn(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + + mock.ExpectQuery("FROM usage_logs"). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{ + "group_id", "group_name", "requests", "total_tokens", + "cost", "actual_cost", "account_cost", + }). + AddRow(int64(1), "azure-cc", int64(100), int64(5000), 10.0, 8.5, 7.2). + AddRow(int64(2), "max", int64(50), int64(2000), 5.0, 4.0, 3.5)) + + results, err := repo.GetGroupStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil) + require.NoError(t, err) + require.Len(t, results, 2) + require.Equal(t, int64(1), results[0].GroupID) + require.Equal(t, "azure-cc", results[0].GroupName) + require.Equal(t, 10.0, results[0].Cost) + require.Equal(t, 8.5, results[0].ActualCost) + require.Equal(t, 7.2, results[0].AccountCost) + require.Equal(t, int64(2), results[1].GroupID) + require.Equal(t, 3.5, results[1].AccountCost) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetStatsWithFiltersAlwaysReturnsAccountCost(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + // No AccountID filter set - TotalAccountCost should still be returned + filters := usagestats.UsageLogFilters{} + + mock.ExpectQuery("FROM usage_logs"). + WillReturnRows(sqlmock.NewRows([]string{ + "total_requests", "total_input_tokens", "total_output_tokens", + "total_cache_tokens", "total_cost", "total_actual_cost", + "total_account_cost", "avg_duration_ms", + }).AddRow(int64(50), int64(1000), int64(2000), int64(100), 15.0, 12.5, 11.0, 100.0)) + mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\)"). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) + mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\)"). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) + mock.ExpectQuery("SELECT CONCAT\\("). + WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"})) + + stats, err := repo.GetStatsWithFilters(context.Background(), filters) + require.NoError(t, err) + require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost must always be returned, even without AccountID filter") + require.Equal(t, 11.0, *stats.TotalAccountCost) require.NoError(t, mock.ExpectationsWereMet()) } @@ -483,10 +572,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) @@ -530,10 +620,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) @@ -577,10 +668,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index e2471ae5..eca5313f 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 query := ` SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier FROM user_group_rate_multipliers ugr - JOIN users u ON u.id = ugr.user_id + JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL WHERE ugr.group_id = $1 ORDER BY ugr.user_id ` diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index d5a13607..913e1c40 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error txClient = r.client } - updated, err := txClient.User.UpdateOneID(userIn.ID). + updateOp := txClient.User.UpdateOneID(userIn.ID). SetEmail(userIn.Email). SetUsername(userIn.Username). SetNotes(userIn.Notes). @@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). - Save(ctx) + SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled). + SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). + SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). + SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). + SetTotalRecharged(userIn.TotalRecharged) + if userIn.BalanceNotifyThreshold == nil { + updateOp = updateOp.ClearBalanceNotifyThreshold() + } + updated, err := updateOp.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) } @@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[ func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error { client := clientFromContext(ctx, r.client) - n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx) + update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount) + // Track cumulative recharge amount for percentage-based notifications + if amount > 0 { + update = update.AddTotalRecharged(amount) + } + n, err := update.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, nil) } @@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) { dst.UpdatedAt = src.UpdatedAt } +// marshalExtraEmails serializes notify email entries to JSON for storage. +func marshalExtraEmails(entries []service.NotifyEmailEntry) string { + return service.MarshalNotifyEmails(entries) +} + // UpdateTotpSecret 更新用户的 TOTP 加密密钥 func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { client := clientFromContext(ctx, r.client) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 1a4892fa..b686b986 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) { "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z", + "balance_notify_enabled": false, + "balance_notify_threshold_type": "", + "balance_notify_threshold": null, + "balance_notify_extra_emails": null, + "total_recharged": 0, "run_mode": "standard" } }`, @@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "claude_code_only": false, + "claude_code_only": false, "allow_messages_dispatch": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, - "allow_messages_dispatch": false, "require_oauth_only": false, "require_privacy_set": false, "created_at": "2025-01-02T03:04:05Z", @@ -587,26 +591,34 @@ func TestAPIContracts(t *testing.T) { "enable_cch_signing": false, "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, + "web_search_emulation_enabled": false, + "custom_menu_items": [], + "custom_endpoints": [], "payment_enabled": false, "payment_min_amount": 0, "payment_max_amount": 0, "payment_daily_limit": 0, "payment_order_timeout_minutes": 0, "payment_max_pending_orders": 0, - "payment_enabled_types": null, "payment_balance_disabled": false, + "payment_balance_recharge_multiplier": 0, + "payment_recharge_fee_rate": 0, "payment_load_balance_strategy": "", "payment_product_name_prefix": "", "payment_product_name_suffix": "", "payment_help_image_url": "", "payment_help_text": "", + "payment_enabled_types": null, "payment_cancel_rate_limit_enabled": false, "payment_cancel_rate_limit_max": 0, "payment_cancel_rate_limit_window": 0, "payment_cancel_rate_limit_unit": "", "payment_cancel_rate_limit_window_mode": "", - "custom_menu_items": [], - "custom_endpoints": [] + "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, + "balance_low_notify_threshold": 0, + "balance_low_notify_recharge_url": "", + "account_quota_notify_emails": [] } }`, }, @@ -699,7 +711,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index ca141c05..023e40bb 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -2,12 +2,15 @@ package server import ( + "context" "log" + "log/slog" "net/http" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/pkg/websearch" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -36,7 +39,6 @@ func ProvideRouter( opsService *service.OpsService, settingService *service.SettingService, redisClient *redis.Client, - langServerService *service.LanguageServerService, ) *gin.Engine { if cfg.Server.Mode == "release" { gin.SetMode(gin.ReleaseMode) @@ -57,7 +59,43 @@ func ProvideRouter( } } - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService) + // Wire up websearch Manager builder so it initializes on startup and rebuilds on config save. + settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) { + if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 { + service.SetWebSearchManager(nil) + return + } + configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers)) + for _, p := range cfg.Providers { + if p.APIKey == "" { + continue + } + pc := websearch.ProviderConfig{ + Type: p.Type, + APIKey: p.APIKey, + QuotaLimit: derefInt64(p.QuotaLimit), + ExpiresAt: p.ExpiresAt, + } + if p.SubscribedAt != nil { + pc.SubscribedAt = p.SubscribedAt + } + if p.ProxyID != nil { + pc.ProxyID = *p.ProxyID + if u, ok := proxyURLs[*p.ProxyID]; ok { + pc.ProxyURL = u + } else { + // Proxy configured but not found — skip this provider to prevent direct connection. + slog.Warn("websearch: proxy not found for provider, skipping", + "provider", p.Type, "proxy_id", *p.ProxyID) + continue + } + } + configs = append(configs, pc) + } + service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) + }) + + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) } // ProvideHTTPServer 提供 HTTP 服务器 @@ -103,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 } } + +func derefInt64(p *int64) int64 { + if p == nil { + return 0 + } + return *p +} diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index aafe4a58..ed2578c8 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { return &clone, nil }, } - userService := service.NewUserService(userRepo, nil, nil) + userService := service.NewUserService(userRepo, nil, nil, nil) router := gin.New() router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index ad9c1b5b..c483a51e 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer userRepo := &stubJWTUserRepo{users: users} authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) - userSvc := service.NewUserService(userRepo, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) r := gin.New() diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 73210bfc..7021ab2e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -18,6 +18,8 @@ const ( NonceTemplate = "__CSP_NONCE__" // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics CloudflareInsightsDomain = "https://static.cloudflareinsights.com" + // StripeDomain is the domain for Stripe.js SDK + StripeDomain = "https://*.stripe.com" ) // GenerateNonce generates a cryptographically secure random nonce. @@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool { strings.HasPrefix(path, "/responses") } -// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. -// This allows the application to work correctly even if the config file has an older CSP policy. +// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, +// and Stripe.js domains. This allows the application to work correctly even if the +// config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { // Add nonce placeholder to script-src if not present if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") { @@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string { policy = addToDirective(policy, "script-src", CloudflareInsightsDomain) } + // Add Stripe.js domain to script-src and frame-src if not present + if !strings.Contains(policy, "stripe.com") { + policy = addToDirective(policy, "script-src", StripeDomain) + policy = addToDirective(policy, "frame-src", StripeDomain) + } + return policy } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index b921da95..9af0fd8e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Beta 策略配置 adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings) adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings) + // Web Search 模拟配置 + adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig) + adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig) + adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation) + adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage) } } diff --git a/backend/internal/server/routes/payment.go b/backend/internal/server/routes/payment.go index 6bf04679..23bd58ad 100644 --- a/backend/internal/server/routes/payment.go +++ b/backend/internal/server/routes/payment.go @@ -39,6 +39,7 @@ func RegisterPaymentRoutes( orders.GET("/:id", paymentHandler.GetOrder) orders.POST("/:id/cancel", paymentHandler.CancelOrder) orders.POST("/:id/refund-request", paymentHandler.RequestRefund) + orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders) } } diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index c3b82742..d004f8b4 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -26,6 +26,15 @@ func RegisterUserRoutes( user.PUT("/password", h.User.ChangePassword) user.PUT("", h.User.UpdateProfile) + // 通知邮箱管理 + notifyEmail := user.Group("/notify-email") + { + notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode) + notifyEmail.POST("/verify", h.User.VerifyNotifyEmail) + notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail) + notifyEmail.DELETE("", h.User.RemoveNotifyEmail) + } + // TOTP 双因素认证 totp := user.Group("/totp") { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 512195e3..52db3073 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "hash/fnv" + "log/slog" "reflect" "sort" "strconv" @@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool { return false } -// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 +// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。 // // 新字段:accounts.extra.openai_passthrough。 // 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 @@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri return resolvedDefault } -// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 +// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。 // 字段:accounts.extra.openai_ws_force_http。 func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { if a == nil || !a.IsOpenAI() || a.Extra == nil { @@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() } -// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。 +// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。 // 字段:accounts.extra.anthropic_passthrough。 // 字段缺失或类型不正确时,按 false(关闭)处理。 func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { @@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { return ok && enabled } -// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。 +// WebSearch 模拟三态常量 +const ( + WebSearchModeDefault = "default" // 跟随渠道配置 + WebSearchModeEnabled = "enabled" // 强制开启 + WebSearchModeDisabled = "disabled" // 强制关闭 +) + +// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。 +// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。 +// 兼容旧 bool 值:true→enabled, false→default(并记录 debug 日志)。 +func (a *Account) GetWebSearchEmulationMode() string { + if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil { + return WebSearchModeDefault + } + raw := a.Extra[featureKeyWebSearchEmulation] + // Tolerant: legacy bool values (pre-migration or stale writes) + if b, ok := raw.(bool); ok { + slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b) + if b { + return WebSearchModeEnabled + } + return WebSearchModeDefault + } + mode, ok := raw.(string) + if !ok { + return WebSearchModeDefault + } + switch mode { + case WebSearchModeEnabled, WebSearchModeDisabled: + return mode + default: + return WebSearchModeDefault + } +} + +// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。 // 字段:accounts.extra.codex_cli_only。 // 字段缺失或类型不正确时,按 false(关闭)处理。 func (a *Account) IsCodexCLIOnlyEnabled() bool { @@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time { return time.Time{} } +// getExtraBool 从 Extra 中读取指定 key 的 bool 值 +func (a *Account) getExtraBool(key string) bool { + if a.Extra == nil { + return false + } + if v, ok := a.Extra[key]; ok { + if b, ok := v.(bool); ok { + return b + } + } + return false +} + // getExtraString 从 Extra 中读取指定 key 的字符串值 func (a *Account) getExtraString(key string) string { if a.Extra == nil { @@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string { return "" } +// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal +func (a *Account) getExtraStringDefault(key, defaultVal string) string { + if v := a.getExtraString(key); v != "" { + return v + } + return defaultVal +} + // getExtraInt 从 Extra 中读取指定 key 的 int 值 func (a *Account) getExtraInt(key string) int { if a.Extra == nil { @@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string { return "UTC" } +// --- Quota Notification Getters --- + +// QuotaNotifyConfig returns the notify configuration for a given quota dimension. +// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal. +func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) { + enabled = a.getExtraBool("quota_notify_" + dim + "_enabled") + threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold") + thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed) + return +} + +func (a *Account) GetQuotaNotifyDailyEnabled() bool { + e, _, _ := a.QuotaNotifyConfig(quotaDimDaily) + return e +} + +func (a *Account) GetQuotaNotifyDailyThreshold() float64 { + _, t, _ := a.QuotaNotifyConfig(quotaDimDaily) + return t +} + +func (a *Account) GetQuotaNotifyDailyThresholdType() string { + _, _, tt := a.QuotaNotifyConfig(quotaDimDaily) + return tt +} + +func (a *Account) GetQuotaNotifyWeeklyEnabled() bool { + e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return e +} + +func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 { + _, t, _ := a.QuotaNotifyConfig(quotaDimWeekly) + return t +} + +func (a *Account) GetQuotaNotifyWeeklyThresholdType() string { + _, _, tt := a.QuotaNotifyConfig(quotaDimWeekly) + return tt +} + +func (a *Account) GetQuotaNotifyTotalEnabled() bool { + e, _, _ := a.QuotaNotifyConfig(quotaDimTotal) + return e +} + +func (a *Account) GetQuotaNotifyTotalThreshold() float64 { + _, t, _ := a.QuotaNotifyConfig(quotaDimTotal) + return t +} + +func (a *Account) GetQuotaNotifyTotalThresholdType() string { + _, _, tt := a.QuotaNotifyConfig(quotaDimTotal) + return tt +} + // nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点 func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time { t := after.In(tz) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go new file mode 100644 index 00000000..90ff450f --- /dev/null +++ b/backend/internal/service/account_stats_pricing.go @@ -0,0 +1,236 @@ +package service + +import ( + "context" + "strings" +) + +// resolveAccountStatsCost 计算账号统计定价费用。 +// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。 +// +// 优先级(先命中为准): +// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关) +// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost) +// 3. 模型定价文件(LiteLLM)中上游模型的默认价格 +// 4. nil → 走默认公式(total_cost × account_rate_multiplier) +// +// upstreamModel 是最终发往上游的模型 ID。 +// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。 +func resolveAccountStatsCost( + ctx context.Context, + channelService *ChannelService, + billingService *BillingService, + accountID int64, + groupID int64, + upstreamModel string, + tokens UsageTokens, + requestCount int, + totalCost float64, +) *float64 { + if channelService == nil || upstreamModel == "" { + return nil + } + channel, err := channelService.GetChannelForGroup(ctx, groupID) + if err != nil || channel == nil { + return nil + } + + platform := channelService.GetGroupPlatform(ctx, groupID) + + // 优先级 1:自定义规则(始终尝试) + if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil { + return cost + } + + // 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前) + if channel.ApplyPricingToAccountStats { + cost := totalCost + if cost <= 0 { + return nil + } + return &cost + } + + // 优先级 3:模型定价文件(LiteLLM)默认价格 + if billingService != nil { + return tryModelFilePricing(billingService, upstreamModel, tokens) + } + + return nil +} + +// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。 +func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 { + pricing, err := billingService.GetModelPricing(model) + if err != nil || pricing == nil { + return nil + } + cost := float64(tokens.InputTokens)*pricing.InputPricePerToken + + float64(tokens.OutputTokens)*pricing.OutputPricePerToken + + float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken + + float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken + + float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken + if cost <= 0 { + return nil + } + return &cost +} + +// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。 +func tryCustomRules( + channel *Channel, accountID, groupID int64, + platform, model string, tokens UsageTokens, requestCount int, +) *float64 { + modelLower := strings.ToLower(model) + for _, rule := range channel.AccountStatsPricingRules { + if !matchAccountStatsRule(&rule, accountID, groupID) { + continue + } + pricing := findPricingForModel(rule.Pricing, platform, modelLower) + if pricing == nil { + continue // 规则匹配但模型不在规则定价中,继续下一条 + } + return calculateStatsCost(pricing, tokens, requestCount) + } + return nil +} + +// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。 +// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。 +// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。 +func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool { + if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 { + return false + } + for _, id := range rule.AccountIDs { + if id == accountID { + return true + } + } + for _, id := range rule.GroupIDs { + if id == groupID { + return true + } + } + return false +} + +// findPricingForModel 在定价列表中查找匹配的模型定价。 +// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。 +func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { + // 精确匹配优先 + for i := range pricingList { + p := &pricingList[i] + if !isPlatformMatch(platform, p.Platform) { + continue + } + for _, m := range p.Models { + if strings.ToLower(m) == modelLower { + return p + } + } + } + // 通配符匹配:按配置顺序,先匹配先使用 + for i := range pricingList { + p := &pricingList[i] + if !isPlatformMatch(platform, p.Platform) { + continue + } + for _, m := range p.Models { + ml := strings.ToLower(m) + if !strings.HasSuffix(ml, "*") { + continue + } + prefix := strings.TrimSuffix(ml, "*") + if strings.HasPrefix(modelLower, prefix) { + return p + } + } + } + return nil +} + +// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 +func isPlatformMatch(queryPlatform, pricingPlatform string) bool { + if queryPlatform == "" || pricingPlatform == "" { + return true + } + return queryPlatform == pricingPlatform +} + +// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。 +func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 { + if pricing == nil { + return nil + } + switch pricing.BillingMode { + case BillingModePerRequest, BillingModeImage: + return calculatePerRequestStatsCost(pricing, requestCount) + default: + return calculateTokenStatsCost(pricing, tokens) + } +} + +// calculatePerRequestStatsCost 按次/图片计费。 +func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 { + if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 { + return nil + } + cost := *pricing.PerRequestPrice * float64(requestCount) + return &cost +} + +// calculateTokenStatsCost Token 计费。 +// If the pricing has intervals, find the matching interval by total token count +// and use its prices instead of the flat pricing fields. +func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 { + p := pricing + if len(pricing.Intervals) > 0 { + totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens + if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil { + p = &ChannelModelPricing{ + InputPrice: iv.InputPrice, + OutputPrice: iv.OutputPrice, + CacheWritePrice: iv.CacheWritePrice, + CacheReadPrice: iv.CacheReadPrice, + PerRequestPrice: iv.PerRequestPrice, + } + } + } + deref := func(ptr *float64) float64 { + if ptr == nil { + return 0 + } + return *ptr + } + cost := float64(tokens.InputTokens)*deref(p.InputPrice) + + float64(tokens.OutputTokens)*deref(p.OutputPrice) + + float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) + + float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) + + float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice) + if cost <= 0 { + return nil + } + return &cost +} + +// applyAccountStatsCost resolves the account stats cost for a usage log entry. +// It resolves the upstream model (falling back to the requested model) and calls +// the 4-level priority chain via resolveAccountStatsCost. +func applyAccountStatsCost( + ctx context.Context, + usageLog *UsageLog, + cs *ChannelService, bs *BillingService, + accountID int64, groupID int64, + upstreamModel, requestedModel string, + tokens UsageTokens, + totalCost float64, +) { + model := upstreamModel + if model == "" { + model = requestedModel + } + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ) +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go new file mode 100644 index 00000000..36e5eb74 --- /dev/null +++ b/backend/internal/service/account_stats_pricing_test.go @@ -0,0 +1,771 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// matchAccountStatsRule +// --------------------------------------------------------------------------- + +func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) { + rule := &AccountStatsPricingRule{} + require.False(t, matchAccountStatsRule(rule, 1, 10)) +} + +func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) { + rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}} + require.True(t, matchAccountStatsRule(rule, 2, 999)) +} + +func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) { + rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}} + require.True(t, matchAccountStatsRule(rule, 999, 20)) +} + +func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.True(t, matchAccountStatsRule(rule, 2, 999)) +} + +func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.True(t, matchAccountStatsRule(rule, 999, 10)) +} + +func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.False(t, matchAccountStatsRule(rule, 999, 999)) +} + +// --------------------------------------------------------------------------- +// findPricingForModel +// --------------------------------------------------------------------------- + +func TestFindPricingForModel(t *testing.T) { + exactPricing := ChannelModelPricing{ + ID: 1, + Models: []string{"claude-opus-4"}, + } + wildcardPricing := ChannelModelPricing{ + ID: 2, + Models: []string{"claude-*"}, + } + platformPricing := ChannelModelPricing{ + ID: 3, + Platform: "openai", + Models: []string{"gpt-4o"}, + } + emptyPlatformPricing := ChannelModelPricing{ + ID: 4, + Models: []string{"gemini-2.5-pro"}, + } + + tests := []struct { + name string + list []ChannelModelPricing + platform string + model string + wantID int64 + wantNil bool + }{ + { + name: "exact match", + list: []ChannelModelPricing{exactPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 1, + }, + { + name: "exact match case insensitive", + list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}}, + platform: "", + model: "claude-opus-4", + wantID: 5, + }, + { + name: "wildcard match", + list: []ChannelModelPricing{wildcardPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 2, + }, + { + name: "exact match takes priority over wildcard", + list: []ChannelModelPricing{wildcardPricing, exactPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 1, + }, + { + name: "platform mismatch skipped", + list: []ChannelModelPricing{platformPricing}, + platform: "anthropic", + model: "gpt-4o", + wantNil: true, + }, + { + name: "empty platform in pricing matches any", + list: []ChannelModelPricing{emptyPlatformPricing}, + platform: "gemini", + model: "gemini-2.5-pro", + wantID: 4, + }, + { + name: "empty platform in query matches any pricing platform", + list: []ChannelModelPricing{platformPricing}, + platform: "", + model: "gpt-4o", + wantID: 3, + }, + { + name: "no match at all", + list: []ChannelModelPricing{exactPricing, wildcardPricing}, + platform: "anthropic", + model: "gpt-4o", + wantNil: true, + }, + { + name: "empty list returns nil", + list: nil, + model: "claude-opus-4", + wantNil: true, + }, + { + name: "wildcard matches by config order (first match wins)", + list: []ChannelModelPricing{ + {ID: 10, Models: []string{"claude-*"}}, + {ID: 11, Models: []string{"claude-opus-*"}}, + }, + platform: "", + model: "claude-opus-4", + wantID: 10, // config order: "claude-*" is first and matches, so it wins + }, + { + name: "shorter wildcard used when longer does not match", + list: []ChannelModelPricing{ + {ID: 10, Models: []string{"claude-*"}}, + {ID: 11, Models: []string{"claude-opus-*"}}, + }, + platform: "", + model: "claude-sonnet-4", + wantID: 10, // only "claude-*" matches + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findPricingForModel(tt.list, tt.platform, tt.model) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantID, result.ID) + }) + } +} + +// --------------------------------------------------------------------------- +// calculateStatsCost +// --------------------------------------------------------------------------- + +func TestCalculateStatsCost_NilPricing(t *testing.T) { + result := calculateStatsCost(nil, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_TokenBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + CacheWritePrice: testPtrFloat64(0.003), + CacheReadPrice: testPtrFloat64(0.0005), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + CacheReadTokens: 300, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005 + // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 + require.InDelta(t, 0.95, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + ImageOutputPrice: testPtrFloat64(0.01), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + ImageOutputTokens: 10, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3 + require.InDelta(t, 0.3, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + // OutputPrice, CacheWritePrice, etc. are all nil → treated as 0 + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // Only input contributes: 100*0.001 = 0.1 + require.InDelta(t, 0.1, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{} // all zeros + result := calculateStatsCost(pricing, tokens, 1) + // totalCost == 0 → returns nil (does not override, falls back to default formula) + require.Nil(t, result) +} + +func TestCalculateStatsCost_PerRequestBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + } + tokens := UsageTokens{InputTokens: 999, OutputTokens: 999} + result := calculateStatsCost(pricing, tokens, 3) + require.NotNil(t, result) + // 0.05 * 3 = 0.15 + require.InDelta(t, 0.15, *result, 1e-12) +} + +func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + // PerRequestPrice is nil + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0), + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + // price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil + require.Nil(t, result) +} + +func TestCalculateStatsCost_ImageBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.10), + } + result := calculateStatsCost(pricing, UsageTokens{}, 2) + require.NotNil(t, result) + // 0.10 * 2 = 0.20 + require.InDelta(t, 0.20, *result, 1e-12) +} + +func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeImage, + // PerRequestPrice is nil + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) { + // BillingMode is empty string (default) → falls into token billing + pricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + require.InDelta(t, 0.2, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// tryCustomRules — 多规则顺序测试 +// --------------------------------------------------------------------------- + +func TestTryCustomRules_FirstMatchWins(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)}, + }, + }, + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + // 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + AccountIDs: []int64{888}, // 不匹配 + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)}, + }, + }, + { + GroupIDs: []int64{1}, // 匹配 + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + // 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0 + require.InDelta(t, 5.0, *result, 1e-12) +} + +func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + AccountIDs: []int64{888}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1) + require.Nil(t, result) // 账号和分组都不匹配 +} + +func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配 + }, + }, + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配 + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2 +} + +// --------------------------------------------------------------------------- +// tryModelFilePricing +// --------------------------------------------------------------------------- + +// newTestBillingServiceWithPrices creates a BillingService with pre-populated +// fallback prices for testing. No config or pricing service is needed. +// The key must match what getFallbackPricing resolves to for a given model name. +// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4". +func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService { + return &BillingService{ + fallbackPrices: prices, + } +} + +func TestTryModelFilePricing_Success(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestTryModelFilePricing_PricingNotFound(t *testing.T) { + // "nonexistent-model" does not match any fallback pattern + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryModelFilePricing(bs, "nonexistent-model", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_NilFallback(t *testing.T) { + // getFallbackPricing returns nil when key maps to nil + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": nil, + }) + tokens := UsageTokens{InputTokens: 100} + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_ZeroCost(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + tokens := UsageTokens{} // all zero tokens → cost = 0 → nil + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.Nil(t, result) +} + +func TestTryModelFilePricing_WithImageOutput(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + ImageOutputPricePerToken: 0.01, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + ImageOutputTokens: 10, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3 + require.InDelta(t, 0.3, *result, 1e-12) +} + +func TestTryModelFilePricing_WithCacheTokens(t *testing.T) { + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + CacheCreationPricePerToken: 0.003, + CacheReadPricePerToken: 0.0005, + }, + }) + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + CacheReadTokens: 300, + } + result := tryModelFilePricing(bs, "claude-sonnet-4", tokens) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005 + // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 + require.InDelta(t, 0.95, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// resolveAccountStatsCost — integration tests covering the 4-level priority chain +// --------------------------------------------------------------------------- + +func TestResolveAccountStatsCost_NilChannelService(t *testing.T) { + result := resolveAccountStatsCost( + context.Background(), + nil, // channelService is nil + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) { + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 1, "", // empty upstream model + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) { + // Group 99 is NOT in the cache, so GetChannelForGroup returns nil + cs := newTestChannelServiceForStats(t, &Channel{ + ID: 1, + Status: StatusActive, + }, 1, "") + + result := resolveAccountStatsCost( + context.Background(), + cs, + newTestBillingServiceWithPrices(map[string]*ModelPricing{}), + 1, 99, "claude-sonnet-4", // groupID 99 has no channel + UsageTokens{InputTokens: 100}, 1, 0.5, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.01), + OutputPrice: testPtrFloat64(0.02), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService not needed when custom rule hits + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored because custom rule hits + ) + require.NotNil(t, result) + // 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 0.75, // totalCost = 0.75 + ) + require.NotNil(t, result) + require.InDelta(t, 0.75, *result, 1e-12) +} + +func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + UsageTokens{}, 1, 0.0, // totalCost = 0 + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, // not enabled + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 0.001, + OutputPricePerToken: 0.002, + }, + }) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "claude-sonnet-4", + tokens, 1, 999.0, // totalCost ignored + ) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + // No custom rules + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + // BillingService with no pricing for the model + bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{}) + + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + + result := resolveAccountStatsCost( + context.Background(), + cs, bs, + 1, 10, "totally-unknown-model", + tokens, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) { + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: false, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, // billingService is nil + 1, 10, "claude-sonnet-4", + UsageTokens{InputTokens: 100}, 1, 0.0, + ) + require.Nil(t, result) +} + +func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) { + // Both custom rule and ApplyPricingToAccountStats are configured; + // custom rule should take precedence. + channel := &Channel{ + ID: 1, + Status: StatusActive, + ApplyPricingToAccountStats: true, + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{10}, + Pricing: []ChannelModelPricing{ + { + ID: 100, + Models: []string{"claude-sonnet-4"}, + InputPrice: testPtrFloat64(0.05), + }, + }, + }, + }, + } + cs := newTestChannelServiceForStats(t, channel, 10, "anthropic") + + tokens := UsageTokens{InputTokens: 100} + + result := resolveAccountStatsCost( + context.Background(), + cs, nil, + 1, 10, "claude-sonnet-4", + tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins) + ) + require.NotNil(t, result) + // Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost) + require.InDelta(t, 5.0, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// helpers for resolveAccountStatsCost tests +// --------------------------------------------------------------------------- + +// newTestChannelServiceForStats creates a ChannelService with a single channel +// mapped to the given groupID, suitable for resolveAccountStatsCost tests. +func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService { + t.Helper() + cache := newEmptyChannelCache() + cache.channelByGroupID[groupID] = channel + cache.groupPlatform[groupID] = platform + cs := &ChannelService{} + cache.loadedAt = time.Now() + cs.cache.Store(cache) + return cs +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 55865945..a5559b7d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -515,22 +515,10 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account _ = s.accountRepo.UpdateExtra(ctx, account.ID, updates) mergeAccountExtra(account, updates) } - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil { - _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) - account.RateLimitResetAt = resetAt - } - } } if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - if isOAuth && s.accountRepo != nil { - if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil { - _ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt) - account.RateLimitResetAt = resetAt - } - } // 401 Unauthorized: 标记账号为永久错误 if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body)) diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index 5125db5b..82606979 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -111,7 +111,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing. require.Contains(t, recorder.Body.String(), "test_complete") } -func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) { +func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) { gin.SetMode(gin.TestMode) ctx, _ := newTestContext() @@ -138,10 +138,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) require.Error(t, err) require.NotEmpty(t, repo.updatedExtra) require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"]) - require.Equal(t, int64(88), repo.rateLimitedID) - require.NotNil(t, repo.rateLimitedAt) - require.NotNil(t, account.RateLimitResetAt) - if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil { - require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second) - } + require.Zero(t, repo.rateLimitedID) + require.Nil(t, repo.rateLimitedAt) + require.Nil(t, account.RateLimitResetAt) } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 0e5741d8..8d5bcec8 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -499,7 +499,6 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou if account == nil { return usage, nil } - syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now) if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil { usage.FiveHour = progress @@ -509,11 +508,8 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou } if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { - if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) { + if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { mergeAccountExtra(account, updates) - if resetAt != nil { - account.RateLimitResetAt = resetAt - } if usage.UpdatedAt == nil { usage.UpdatedAt = &now } @@ -594,26 +590,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no return true } -func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) { +func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) { if account == nil || !account.IsOAuth() { - return nil, nil, nil + return nil, nil } accessToken := account.GetOpenAIAccessToken() if accessToken == "" { - return nil, nil, fmt.Errorf("no access token available") + return nil, fmt.Errorf("no access token available") } modelID := openaipkg.DefaultTestModel payload := createOpenAITestPayload(modelID, true) payloadBytes, err := json.Marshal(payload) if err != nil { - return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err) + return nil, fmt.Errorf("marshal openai probe payload: %w", err) } reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second) defer cancel() req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes)) if err != nil { - return nil, nil, fmt.Errorf("create openai probe request: %w", err) + return nil, fmt.Errorf("create openai probe request: %w", err) } req.Host = "chatgpt.com" req.Header.Set("Content-Type", "application/json") @@ -642,67 +638,51 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco ResponseHeaderTimeout: 10 * time.Second, }) if err != nil { - return nil, nil, fmt.Errorf("build openai probe client: %w", err) + return nil, fmt.Errorf("build openai probe client: %w", err) } resp, err := client.Do(req) if err != nil { - return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err) + return nil, fmt.Errorf("openai codex probe request failed: %w", err) } defer func() { _ = resp.Body.Close() }() - updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp) + updates, err := extractOpenAICodexProbeUpdates(resp) if err != nil { - return nil, nil, err + return nil, err } - if len(updates) > 0 || resetAt != nil { - s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt) - return updates, resetAt, nil + if len(updates) > 0 { + s.persistOpenAICodexProbeSnapshot(account.ID, updates) + return updates, nil } - return nil, nil, nil + return nil, nil } -func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) { +func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any) { if s == nil || s.accountRepo == nil || accountID <= 0 { return } - if len(updates) == 0 && resetAt == nil { + if len(updates) == 0 { return } go func() { updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second) defer updateCancel() - if len(updates) > 0 { - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - } - if resetAt != nil { - _ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt) - } + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) }() } -func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) { +func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { if resp == nil { - return nil, nil, nil + return nil, nil } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - baseTime := time.Now() - updates := buildCodexUsageExtraUpdates(snapshot, baseTime) - resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime) - if len(updates) > 0 { - return updates, resetAt, nil - } - return nil, resetAt, nil + return buildCodexUsageExtraUpdates(snapshot, time.Now()), nil } if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) + return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode) } - return nil, nil, nil -} - -func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) { - updates, _, err := extractOpenAICodexProbeSnapshot(resp) - return updates, err + return nil, nil } func mergeAccountExtra(account *Account, updates map[string]any) { diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go index fe255225..28b49838 100644 --- a/backend/internal/service/account_usage_service_test.go +++ b/backend/internal/service/account_usage_service_test.go @@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) } } -func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) { - t.Parallel() - - headers := make(http.Header) - headers.Set("x-codex-primary-used-percent", "100") - headers.Set("x-codex-primary-reset-after-seconds", "604800") - headers.Set("x-codex-primary-window-minutes", "10080") - headers.Set("x-codex-secondary-used-percent", "100") - headers.Set("x-codex-secondary-reset-after-seconds", "18000") - headers.Set("x-codex-secondary-window-minutes", "300") - - updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers}) - if err != nil { - t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err) - } - if len(updates) == 0 { - t.Fatal("expected codex probe updates from 429 headers") - } - if resetAt == nil { - t.Fatal("expected resetAt from exhausted codex headers") - } -} - -func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) { +func TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra(t *testing.T) { t.Parallel() repo := &accountUsageCodexProbeRepo{ @@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes rateLimitCh: make(chan time.Time, 1), } svc := &AccountUsageService{accountRepo: repo} - resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second) - svc.persistOpenAICodexProbeSnapshot(321, map[string]any{ "codex_7d_used_percent": 100.0, - "codex_7d_reset_at": resetAt.Format(time.RFC3339), - }, &resetAt) + "codex_7d_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339), + }) select { case updates := <-repo.updateExtraCh: @@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes t.Fatalf("codex_7d_used_percent = %v, want 100", got) } case <-time.After(2 * time.Second): - t.Fatal("waiting for codex probe extra persistence timed out") + t.Fatal("等待 codex 探测快照写入 extra 超时") } select { case got := <-repo.rateLimitCh: - if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) { - t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt) - } - case <-time.After(2 * time.Second): - t.Fatal("waiting for codex probe rate limit persistence timed out") + t.Fatalf("不应将探测快照写入运行时限流状态: %v", got) + case <-time.After(200 * time.Millisecond): + } +} + +func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit(t *testing.T) { + t.Parallel() + + resetAt := time.Now().Add(6 * 24 * time.Hour).UTC().Truncate(time.Second) + repo := &accountUsageCodexProbeRepo{ + rateLimitCh: make(chan time.Time, 1), + } + svc := &AccountUsageService{accountRepo: repo} + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_5h_used_percent": 1.0, + "codex_5h_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339), + "codex_7d_used_percent": 100.0, + "codex_7d_reset_at": resetAt.Format(time.RFC3339), + }, + } + + usage, err := svc.getOpenAIUsage(context.Background(), account) + if err != nil { + t.Fatalf("getOpenAIUsage() error = %v", err) + } + if usage.SevenDay == nil || usage.SevenDay.Utilization != 100.0 { + t.Fatalf("预期 7 天用量仍然可见,实际为 %#v", usage.SevenDay) + } + if account.RateLimitResetAt != nil { + t.Fatalf("不应让已耗尽的 codex extra 改写运行时限流状态: %v", account.RateLimitResetAt) + } + select { + case got := <-repo.rateLimitCh: + t.Fatalf("不应将已耗尽的 codex extra 持久化为运行时限流状态: %v", got) + case <-time.After(200 * time.Millisecond): } } diff --git a/backend/internal/service/account_websearch_test.go b/backend/internal/service/account_websearch_test.go new file mode 100644 index 00000000..6ed69d4c --- /dev/null +++ b/backend/internal/service/account_websearch_test.go @@ -0,0 +1,105 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetWebSearchEmulationMode_Enabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, + } + require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Disabled(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"}, + } + require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_Default(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "default"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: true}, + } + // bool true → tolerant fallback → enabled (not default) + require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: false}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) { + var a *Account + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: nil, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_MissingField(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) { + a := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} + +func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) { + a := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"}, + } + require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode()) +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 97b42c24..7c26a47c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, if err != nil { return nil, 0, err } - now := time.Now() - for i := range accounts { - syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now) - } return accounts, result.Total, nil } diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index f9fd6742..419ddbc3 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { - panic("unexpected") -} func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") } func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") } +func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { + panic("unexpected") +} // apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests. type apiKeyRepoStubForGroupUpdate struct { @@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } -func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { - panic("unexpected") -} func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) { panic("unexpected") } @@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) { panic("unexpected") } +func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) { + panic("unexpected") +} // groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests. type groupRepoStubForGroupUpdate struct { diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go index f039612c..141466dc 100644 --- a/backend/internal/service/admin_service_clear_error_test.go +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -12,12 +12,12 @@ import ( type accountRepoStubForClearAccountError struct { mockAccountRepoForGemini - account *Account - clearErrorCalls int - clearRateLimitCalls int - clearAntigravityCalls int + account *Account + clearErrorCalls int + clearRateLimitCalls int + clearAntigravityCalls int clearModelRateLimitCalls int - clearTempUnschedCalls int + clearTempUnschedCalls int } func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) { @@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes resetAt := time.Now().Add(5 * time.Minute) repo := &accountRepoStubForClearAccountError{ account: &Account{ - ID: 31, - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Status: StatusError, - ErrorMessage: "refresh failed", - RateLimitResetAt: &resetAt, - TempUnschedulableUntil: &until, + ID: 31, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusError, + ErrorMessage: "refresh failed", + RateLimitResetAt: &resetAt, + TempUnschedulableUntil: &until, TempUnschedulableReason: "missing refresh token", }, } diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index c2e96df1..b1660ea7 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct { Role string `json:"role"` Balance float64 `json:"balance"` Concurrency int `json:"concurrency"` + + // Balance notification fields (required for CheckBalanceAfterDeduction) + Email string `json:"email"` + Username string `json:"username"` + BalanceNotifyEnabled bool `json:"balance_notify_enabled"` + BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` + BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` + BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` + TotalRecharged float64 `json:"total_recharged"` } // APIKeyAuthGroupSnapshot 分组快照 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 8069ed4f..2bd9a091 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "log/slog" "math/rand/v2" "time" @@ -13,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 3 +const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold type apiKeyAuthCacheConfig struct { l1Size int @@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context s.authCacheL1.Del(cacheKey) }); err != nil { // Log but don't fail - L1 cache will still work, just without cross-instance invalidation - println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error()) + slog.Warn("failed to start auth cache invalidation subscriber", "error", err) } } @@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { RateLimit1d: apiKey.RateLimit1d, RateLimit7d: apiKey.RateLimit7d, User: APIKeyAuthUserSnapshot{ - ID: apiKey.User.ID, - Status: apiKey.User.Status, - Role: apiKey.User.Role, - Balance: apiKey.User.Balance, - Concurrency: apiKey.User.Concurrency, + ID: apiKey.User.ID, + Status: apiKey.User.Status, + Role: apiKey.User.Role, + Balance: apiKey.User.Balance, + Concurrency: apiKey.User.Concurrency, + Email: apiKey.User.Email, + Username: apiKey.User.Username, + BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled, + BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType, + BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, + BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, + TotalRecharged: apiKey.User.TotalRecharged, }, } if apiKey.Group != nil { @@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho RateLimit1d: snapshot.RateLimit1d, RateLimit7d: snapshot.RateLimit7d, User: &User{ - ID: snapshot.User.ID, - Status: snapshot.User.Status, - Role: snapshot.User.Role, - Balance: snapshot.User.Balance, - Concurrency: snapshot.User.Concurrency, + ID: snapshot.User.ID, + Status: snapshot.User.Status, + Role: snapshot.User.Role, + Balance: snapshot.User.Balance, + Concurrency: snapshot.User.Concurrency, + Email: snapshot.User.Email, + Username: snapshot.User.Username, + BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled, + BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType, + BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, + BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, + TotalRecharged: snapshot.User.TotalRecharged, }, } if snapshot.Group != nil { diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 7b50e90d..103bafe7 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin return nil } +func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) { + return nil, nil +} + +func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error { + return nil +} + +func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error { + return nil +} + func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) { return nil, nil } @@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai return nil } +func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + return 0, nil +} + +func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + return 0, nil +} + func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService { cfg := &config.Config{ JWT: config.JWTConfig{ diff --git a/backend/internal/service/balance_notify_check_test.go b/backend/internal/service/balance_notify_check_test.go new file mode 100644 index 00000000..7bb4cf9e --- /dev/null +++ b/backend/internal/service/balance_notify_check_test.go @@ -0,0 +1,404 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an +// in-memory settings repo and a non-nil emailService so that the guard-clause +// nil-checks pass. The emailService is intentionally minimal — tests must +// avoid crossing scenarios that would actually dispatch emails. +func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) { + repo := newMockSettingRepo() + // EmailService is a concrete type; construct with the same repo so that + // any accidental fallback reads still succeed. Tests should not trigger a + // crossing that reaches SendEmail. + email := NewEmailService(repo, nil) + return NewBalanceNotifyService(email, repo, nil), repo +} + +// ---------- guard clauses ---------- + +func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50) +} + +func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: false} + // Even with a crossing, disabled flag short-circuits. + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "0" + u := &User{ID: 1, BalanceNotifyEnabled: true} + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default + customThreshold := 5.0 + u := &User{ + ID: 1, + BalanceNotifyEnabled: true, + BalanceNotifyThreshold: &customThreshold, + } + // User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not + // cross 5, so nothing fires (verified by absence of panic). + s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15) +} + +func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "10" + u := &User{ID: 1, BalanceNotifyEnabled: true} + + // 100 -> 95, both remain above threshold=10, no crossing. + s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5) + // 5 -> 3, both already below threshold, no crossing (only fires on first + // cross from above-to-below). + s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2) +} + +// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ---------- + +func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + // Should not panic. + s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil) +} + +func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil) +} + +func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil) +} + +func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + a := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "quota_notify_daily_enabled": true, + "quota_notify_daily_threshold": 100.0, + "quota_daily_limit": 1000.0, + "quota_daily_used": 950.0, + }, + } + // Global disabled → no processing even if a dim would cross. + s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil) +} + +// ---------- sanity: internal helpers still work ---------- + +func TestGetBalanceNotifyConfig_AllFields(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5" + repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay" + + enabled, threshold, url := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 12.5, threshold) + require.Equal(t, "https://example.com/pay", url) +} + +func TestGetBalanceNotifyConfig_Disabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "false" + + enabled, _, _ := s.getBalanceNotifyConfig(context.Background()) + require.False(t, enabled) +} + +func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeyBalanceLowNotifyEnabled] = "true" + repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number" + + enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background()) + require.True(t, enabled) + require.Equal(t, 0.0, threshold) +} + +func TestIsAccountQuotaNotifyEnabled(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + + // Missing key → false + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "false" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false" + require.False(t, s.isAccountQuotaNotifyEnabled(context.Background())) + + // Explicit "true" + repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true" + require.True(t, s.isAccountQuotaNotifyEnabled(context.Background())) +} + +func TestGetSiteName_FallsBackToDefault(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + name := s.getSiteName(context.Background()) + require.Equal(t, defaultSiteName, name) +} + +func TestGetSiteName_Configured(t *testing.T) { + s, repo := newBalanceNotifyServiceForTest() + repo.data[SettingKeySiteName] = "My Site" + require.Equal(t, "My Site", s.getSiteName(context.Background())) +} + +// ---------- crossedDownward ---------- + +func TestCrossedDownward_CrossesBelow(t *testing.T) { + // oldBalance > threshold, newBalance < threshold → true + require.True(t, crossedDownward(100, 5, 10)) +} + +func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) { + // oldBalance > threshold, newBalance == threshold → false (not below) + require.False(t, crossedDownward(100, 10, 10)) +} + +func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) { + // oldBalance == threshold, newBalance < threshold → true + // (at-or-above → below counts as a crossing) + require.True(t, crossedDownward(10, 5, 10)) +} + +func TestCrossedDownward_AlreadyBelow(t *testing.T) { + // oldBalance < threshold → false (already below, no new crossing) + require.False(t, crossedDownward(5, 3, 10)) +} + +func TestCrossedDownward_BothAbove(t *testing.T) { + // oldBalance > threshold, newBalance > threshold → false (no crossing) + require.False(t, crossedDownward(100, 50, 10)) +} + +func TestCrossedDownward_ZeroThreshold(t *testing.T) { + // threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives + // Typical case: positive balances should not fire when threshold is 0. + require.False(t, crossedDownward(10, 5, 0)) + require.False(t, crossedDownward(0, 0, 0)) +} + +func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) { + // Edge case: newBalance goes negative with threshold=0. + require.True(t, crossedDownward(5, -1, 0)) +} + +func TestCrossedDownward_NegativeValues(t *testing.T) { + // Both already negative, threshold is positive → no crossing (already below). + require.False(t, crossedDownward(-5, -10, 10)) +} + +func TestCrossedDownward_LargeDecrement(t *testing.T) { + // A single large deduction crosses the threshold. + require.True(t, crossedDownward(1000, 0.5, 100)) +} + +func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) { + // A tiny deduction stays above threshold. + require.False(t, crossedDownward(100, 99.99, 10)) +} + +// ---------- checkQuotaDimCrossings ---------- + +func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // Empty dims → no crossing, no panic. + s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite") + s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: false, // disabled + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Disabled dimension should be skipped even if crossing would occur. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 0, // zero threshold + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + // Zero threshold → skipped. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger) + // currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 800, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200 + // Negative resolved threshold → skipped. + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 1200, + thresholdType: thresholdTypeFixed, + currentUsed: 950, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700 + // currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing. + dims := []quotaDim{ + { + name: quotaDimWeekly, + enabled: true, + threshold: 30, + thresholdType: thresholdTypePercentage, + currentUsed: 500, + limit: 1000, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // limit=0 → resolvedThreshold returns 0 → skipped. + dims := []quotaDim{ + { + name: quotaDimTotal, + enabled: true, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 50, + limit: 0, + }, + } + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} + +func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) { + s, _ := newBalanceNotifyServiceForTest() + account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic} + // dim1: no crossing (both below effective threshold) + // dim2: disabled (skipped) + // dim3: zero threshold (skipped) + dims := []quotaDim{ + { + name: quotaDimDaily, + enabled: true, + threshold: 400, + thresholdType: thresholdTypeFixed, + currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below + limit: 1000, + }, + { + name: quotaDimWeekly, + enabled: false, + threshold: 100, + thresholdType: thresholdTypeFixed, + currentUsed: 900, + limit: 1000, + }, + { + name: quotaDimTotal, + enabled: true, + threshold: 0, + thresholdType: thresholdTypeFixed, + currentUsed: 500, + limit: 1000, + }, + } + // None should trigger. No panic expected. + s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite") +} diff --git a/backend/internal/service/balance_notify_email_body_test.go b/backend/internal/service/balance_notify_email_body_test.go new file mode 100644 index 00000000..aee5a5bc --- /dev/null +++ b/backend/internal/service/balance_notify_email_body_test.go @@ -0,0 +1,147 @@ +//go:build unit + +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// These tests guard against fmt.Sprintf arg-count mismatches in the email +// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in +// the output, which these assertions will catch. + +// ---------- buildBalanceLowEmailBody ---------- + +func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "") + + // All substituted values should appear in the output. + require.Contains(t, body, "MySite") + require.Contains(t, body, "Alice") + require.Contains(t, body, "$3.14") + require.Contains(t, body, "$10.00") + + // No fmt.Sprintf format error markers. + require.NotContains(t, body, "%!") + require.NotContains(t, body, "MISSING") + require.NotContains(t, body, "EXTRA") +} + +func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) { + s := &BalanceNotifyService{} + body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay") + + // The recharge anchor element should appear with the URL. + require.Contains(t, body, `href="https://example.com/pay"`) + require.Contains(t, body, "立即充值") + require.NotContains(t, body, "%!") +} + +func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) { + s := &BalanceNotifyService{} + // Try a URL with characters that need HTML escaping. + body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b= diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 87b7ff0a..b0ce7e70 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -165,7 +165,6 @@ @@ -249,7 +248,7 @@ const availableModels = ref([]) const selectedModelId = ref('') const testPrompt = ref('') const loadingModels = ref(false) -let eventSource: EventSource | null = null +let abortController: AbortController | null = null const generatedImages = ref([]) const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash'] const supportsGeminiImageTest = computed(() => { @@ -279,7 +278,7 @@ watch( resetState() await loadAvailableModels() } else { - closeEventSource() + abortStream() } } ) @@ -329,18 +328,14 @@ const resetState = () => { } const handleClose = () => { - // 防止在连接测试进行中关闭对话框 - if (status.value === 'connecting') { - return - } - closeEventSource() + abortStream() emit('close') } -const closeEventSource = () => { - if (eventSource) { - eventSource.close() - eventSource = null +const abortStream = () => { + if (abortController) { + abortController.abort() + abortController = null } } @@ -365,7 +360,9 @@ const startTest = async () => { addLine(t('admin.accounts.testAccountTypeLabel', { type: props.account.type }), 'text-gray-400') addLine('', 'text-gray-300') - closeEventSource() + abortStream() + + abortController = new AbortController() try { // Create EventSource for SSE @@ -381,7 +378,8 @@ const startTest = async () => { body: JSON.stringify({ model_id: selectedModelId.value, prompt: supportsGeminiImageTest.value ? testPrompt.value.trim() : '' - }) + }), + signal: abortController.signal }) if (!response.ok) { @@ -418,10 +416,15 @@ const startTest = async () => { } } } - } catch (error: any) { + } catch (error: unknown) { + if (error instanceof DOMException && error.name === 'AbortError') { + status.value = 'idle' + return + } status.value = 'error' - errorMessage.value = error.message || 'Unknown error' - addLine(`Error: ${errorMessage.value}`, 'text-red-400') + const msg = error instanceof Error ? error.message : 'Unknown error' + errorMessage.value = msg + addLine(`Error: ${msg}`, 'text-red-400') } } diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 37e18c35..1c023fb3 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -1,5 +1,5 @@ diff --git a/frontend/src/composables/useQuotaNotifyState.ts b/frontend/src/composables/useQuotaNotifyState.ts new file mode 100644 index 00000000..25fad0e8 --- /dev/null +++ b/frontend/src/composables/useQuotaNotifyState.ts @@ -0,0 +1,69 @@ +import { reactive, ref } from 'vue' +import { adminAPI } from '@/api/admin' +import { QUOTA_THRESHOLD_TYPE_FIXED, type QuotaThresholdType } from '@/constants/account' + +export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const +export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number] + +interface DimState { + enabled: boolean | null + threshold: number | null + thresholdType: QuotaThresholdType | null +} + +export function useQuotaNotifyState() { + const globalEnabled = ref(false) + const state = reactive>({ + daily: { enabled: null, threshold: null, thresholdType: null }, + weekly: { enabled: null, threshold: null, thresholdType: null }, + total: { enabled: null, threshold: null, thresholdType: null }, + }) + + function loadGlobalState() { + adminAPI.settings + .getSettings() + .then((settings) => { + globalEnabled.value = settings.account_quota_notify_enabled === true + }) + .catch(() => { + globalEnabled.value = false + }) + } + + function loadFromExtra(extra: Record | null | undefined) { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null + state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null + state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as QuotaThresholdType) ?? null + } + } + + function writeToExtra(extra: Record, mode: 'create' | 'update') { + for (const d of QUOTA_NOTIFY_DIMS) { + const s = state[d] + if (s.enabled) { + extra[`quota_notify_${d}_enabled`] = true + if (s.threshold != null) { + extra[`quota_notify_${d}_threshold`] = s.threshold + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_threshold`] + } + extra[`quota_notify_${d}_threshold_type`] = s.thresholdType || QUOTA_THRESHOLD_TYPE_FIXED + } else if (mode === 'update') { + delete extra[`quota_notify_${d}_enabled`] + delete extra[`quota_notify_${d}_threshold`] + delete extra[`quota_notify_${d}_threshold_type`] + } + } + } + + function reset() { + for (const d of QUOTA_NOTIFY_DIMS) { + state[d].enabled = null + state[d].threshold = null + state[d].thresholdType = null + } + } + + return { globalEnabled, state, loadGlobalState, loadFromExtra, writeToExtra, reset } +} diff --git a/frontend/src/composables/useTableSelection.ts b/frontend/src/composables/useTableSelection.ts index a65144a9..f0e096ff 100644 --- a/frontend/src/composables/useTableSelection.ts +++ b/frontend/src/composables/useTableSelection.ts @@ -76,6 +76,12 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions) => void) => { + const draft = new Set(selectedSet.value) + updater(draft) + replaceSelectedSet(draft) + } + const selectVisible = () => { toggleVisible(true) } @@ -93,6 +99,7 @@ export function useTableSelection({ rows, getId }: UseTableSelectionOptions 0 if set', + quotaLimitMustBePositive: 'Quota limit must be greater than 0', + subscribedAt: 'Subscribed At', + subscribedAtHint: 'Quota resets monthly from this date; leave empty to disable auto-reset', + quotaUsage: 'Usage', + resetUsage: 'Reset', + resetUsageConfirm: 'Reset usage counter for this provider?', + resetUsageSuccess: 'Usage counter reset', + proxy: 'Proxy', + removeProvider: 'Remove', + noProviders: 'No search providers configured', + test: 'Test', + testDefaultQuery: 'Major world events this year', + testing: 'Searching...', + testResultTitle: 'Search Results', + testResultProvider: 'Provider', + testNoResults: 'No results found', + }, site: { title: 'Site Settings', description: 'Customize site branding', @@ -4467,6 +4568,12 @@ export default { minAmount: 'Minimum Amount', maxAmount: 'Maximum Amount', dailyLimit: 'Daily Limit', + balanceRechargeMultiplier: 'Balance Recharge Multiplier', + balanceRechargeMultiplierHint: 'How many USD balance the user receives for each 1 CNY paid', + balanceRechargePreview: 'Preview: 1 CNY = {usd} USD', + rechargeFeeRate: 'Recharge Fee Rate', + rechargeFeeRateHint: 'Percentage of service fee charged on top of recharge amount, 0 means no fee', + rechargeFeePreview: 'Preview: Recharge 100, fee {fee}', orderTimeout: 'Order Timeout', orderTimeoutHint: 'In minutes, minimum 1', maxPendingOrders: 'Max Pending Orders', @@ -4550,6 +4657,27 @@ export default { supportedTypes: 'Supported Payment Types', supportedTypesHint: 'Comma-separated, e.g. alipay,wxpay', refundEnabled: 'Allow Refund', + allowUserRefund: 'Allow User Refund', + }, + balanceNotify: { + title: 'Balance Low Notification', + description: 'Send email notification when user balance falls below threshold', + enabled: 'Enable Balance Low Notification', + threshold: 'Default Threshold', + thresholdHint: 'Used when user has not set a custom value', + thresholdPlaceholder: 'Enter amount', + rechargeUrl: 'Recharge Page URL', + rechargeUrlPlaceholder: 'https://example.com/payment', + rechargeUrlHint: 'A top-up button will appear in the email when set', + }, + quotaNotify: { + title: 'Account Quota Notification', + description: 'Notify admins when account quota usage reaches alert threshold', + enabled: 'Enable Account Quota Notification', + emails: 'Notification Emails', + emailsHint: 'Leave empty to disable notifications', + addEmail: 'Add Email', + emailPlaceholder: 'Enter email address', }, smtp: { title: 'SMTP Settings', @@ -5204,6 +5332,8 @@ export default { payment: { title: 'Recharge / Subscription', amountLabel: 'Amount', + paymentAmount: 'Payment Amount', + creditedBalance: 'Credited Balance', quickAmounts: 'Quick Amounts', customAmount: 'Custom Amount', enterAmount: 'Enter amount', @@ -5259,6 +5389,10 @@ export default { orderNo: 'Order No.', amount: 'Amount', payAmount: 'Paid', + creditedAmount: 'Credited Amount', + fee: 'Fee', + baseAmount: 'Base Amount', + includedInPayAmount: 'included in paid amount', status: 'Status', paymentMethod: 'Payment Method', createdAt: 'Created', @@ -5288,6 +5422,7 @@ export default { amountTooLow: 'Minimum amount is {min}', amountTooHigh: 'Maximum amount is {max}', amountNoMethod: 'No payment method available for this amount', + rechargeRatePreview: 'Current rate: 1 CNY = {usd} USD', refundReason: 'Refund Reason', refundReasonPlaceholder: 'Please describe your refund reason', stripeLoadFailed: 'Failed to load payment component. Please refresh and try again.', @@ -5370,6 +5505,7 @@ export default { refundSuccess: 'Refund successful', refundInfo: 'Refund Info', refundEnabled: 'Refund Enabled', + allowUserRefund: 'Allow User Refund', alreadyRefunded: 'Already Refunded', deductBalance: 'Deduct Balance', deductBalanceHint: 'Subtract recharged amount from user balance', @@ -5439,6 +5575,9 @@ export default { tabPlanConfig: 'Plan Configuration', tabUserSubs: 'User Subscriptions', selectGroup: 'Select a group', + groupRequired: 'Please select a subscription group', + priceRequired: 'Price must be greater than 0', + validityDaysRequired: 'Validity days must be greater than 0', groupMissing: 'Missing', groupInfo: 'Group Info', platform: 'Platform', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 2038970a..6f57ab3e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -247,6 +247,8 @@ export default { loading: '加载中...', justNow: '刚刚', save: '保存', + saved: '保存成功', + deleted: '删除成功', cancel: '取消', delete: '删除', edit: '编辑', @@ -304,6 +306,7 @@ export default { saving: '保存中...', selectedCount: '(已选 {count} 个)', refresh: '刷新', + view: '查看', settings: '设置', chooseFile: '选择文件', notAvailable: '不可用', @@ -733,6 +736,7 @@ export default { totalCost: '总消费', standardCost: '标准', actualCost: '实际', + accountCost: '成本', userBilled: '用户扣费', accountBilled: '账号计费', accountMultiplier: '账号倍率', @@ -778,6 +782,8 @@ export default { inputTokenPrice: '输入单价', outputTokenPrice: '输出单价', perMillionTokens: '/ 1M Token', + unitPrice: '单次价格', + imageUnitPrice: '单张价格', cacheRead: '读取', cacheWrite: '写入', serviceTier: '服务档位', @@ -906,6 +912,38 @@ export default { sendCode: '发送验证码', codeSent: '验证码已发送到您的邮箱', sendCodeFailed: '发送验证码失败' + }, + balanceNotify: { + title: '余额不足提醒', + description: '当账户余额低于阈值时发送邮件提醒', + enabled: '启用余额不足提醒', + threshold: '自定义提醒阈值', + thresholdHint: '留空使用系统默认值', + thresholdPlaceholder: '输入金额', + systemDefault: '系统默认值', + extraEmails: '通知邮箱', + extraEmailsHint: '必须添加并验证邮箱后,余额不足时才能收到提醒邮件', + primaryEmail: '主邮箱', + noExtraEmails: '暂无额外通知邮箱', + enterEmail: '输入邮箱地址', + addEmail: '添加邮箱', + emailPlaceholder: '输入邮箱地址', + sendCode: '发送验证码', + resend: '重发', + codeSent: '验证码已发送', + codeSentTo: '验证码已发送到 {email}', + enterCode: '输入验证码', + codePlaceholder: '6位验证码', + verify: '验证', + emailAdded: '邮箱已添加', + emailRemoved: '邮箱已移除', + verifySuccess: '邮箱添加成功', + removeEmail: '移除', + removeSuccess: '邮箱已移除', + emailDuplicate: '该邮箱已存在', + maxEmailsReached: '已达到通知邮箱数量上限', + unverified: '未验证', + verified: '已验证', } }, @@ -989,6 +1027,7 @@ export default { totalCost: '总消费', actual: '实际', standard: '标准', + accountCost: '成本', todayTokens: '今日 Token', totalTokens: '总 Token', input: '输入', @@ -1915,12 +1954,28 @@ export default { defaultPerRequestPrice: '默认单次价格(未命中层级时使用)', defaultImagePrice: '默认图片价格(未命中层级时使用)', platformConfig: '平台配置', + webSearchEmulation: 'Web Search 模拟', + webSearchEmulationHint: '⚠️ 开启后该渠道下所有 Anthropic 分组的账号将自动拦截 web_search 请求,请谨慎操作', + webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关', basicSettings: '基础设置', addPlatform: '添加平台', noPlatforms: '点击"添加平台"开始配置渠道', mappingCount: '条映射', pricingEntry: '定价配置', - noModels: '未添加模型' + noModels: '未添加模型', + applyPricingToAccountStats: '应用模型定价到账号统计', + applyPricingToAccountStatsDesc: '启用后,未被自定义规则匹配的请求将使用模型定价文件中的标准价格计算账号统计费用', + accountStatsPricingRules: '自定义账号统计定价规则', + addRule: '添加规则', + noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。', + ruleName: '规则名称(可选)', + ruleGroups: '分组', + ruleAccounts: '账号', + searchAccountPlaceholder: '搜索账号...', + ruleAccountsHint: '留空表示匹配所有账号', + ruleModelPricing: '模型定价', + noGroupsInChannel: '上方平台标签页中未选择分组', + unnamed: '未命名' } }, @@ -2212,6 +2267,12 @@ export default { }, quotaLimitAmount: '总限额', quotaLimitAmountHint: '累计消费上限,不会自动重置。', + quotaNotify: { + alert: '提醒阈值', + enabled: '启用告警', + threshold: '告警金额', + thresholdPlaceholder: '输入百分比', + }, testConnection: '测试连接', reAuthorize: '重新授权', refreshToken: '刷新令牌', @@ -2472,7 +2533,13 @@ export default { anthropic: { apiKeyPassthrough: '自动透传(仅替换认证)', apiKeyPassthroughDesc: - '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。' + '仅对 Anthropic API Key 生效。开启后,messages/count_tokens 请求将透传上游并仅替换认证,保留计费/并发/审计及必要安全过滤;关闭即可回滚到现有兼容链路。', + webSearchEmulation: 'Web Search 模拟', + webSearchEmulationDesc: + '为该 API Key 账号启用 web search 模拟。客户端发送纯 web_search 请求时,由网关调用第三方搜索 API 并构造响应返回。默认跟随渠道配置。', + webSearchDefault: '默认', + webSearchEnabled: '开启', + webSearchDisabled: '关闭', }, modelRestriction: '模型限制(可选)', modelWhitelist: '模型白名单', @@ -4520,6 +4587,40 @@ export default { cchSigning: 'CCH 签名', cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。', }, + webSearchEmulation: { + title: 'Web Search 模拟', + description: '为不原生支持搜索的 Anthropic API Key 账号注入 web search 能力', + enabled: '启用 Web Search 模拟', + enabledHint: '全局开关。关闭后所有渠道和账号的 web search 模拟均不生效。', + providers: '搜索服务商', + addProvider: '添加服务商', + providerType: '服务商类型', + apiKey: 'API Key', + apiKeyPlaceholder: '输入 API Key', + apiKeyConfigured: '已配置', + showApiKey: '显示', + hideApiKey: '隐藏', + copyApiKey: '复制', + copied: '已复制', + quotaLimit: '配额上限', + quotaLimitHint: '留空表示无限制;填写时必须大于 0', + quotaLimitMustBePositive: '配额上限必须大于 0', + subscribedAt: '订阅时间', + subscribedAtHint: '配额从此日期起每月自动重置;留空则不自动重置', + quotaUsage: '用量', + resetUsage: '重置', + resetUsageConfirm: '确定要重置此服务商的用量计数吗?', + resetUsageSuccess: '用量已重置', + proxy: '代理', + removeProvider: '删除', + noProviders: '未配置搜索服务商', + test: '测试', + testDefaultQuery: '搜索今年世界大事件', + testing: '搜索中...', + testResultTitle: '搜索结果', + testResultProvider: '服务商', + testNoResults: '无搜索结果', + }, site: { title: '站点设置', description: '自定义站点品牌', @@ -4627,10 +4728,16 @@ export default { enabledHint: '启用或禁用支付系统', enabledPaymentTypes: '启用的服务商', enabledPaymentTypesHint: '禁用服务商将同时禁用对应的实例。', - findProvider: '正在寻找合适的 EasyPay 服务商?', + findProvider: '正在寻找合适的易支付服务商?', minAmount: '最低金额', maxAmount: '最高金额', dailyLimit: '每日限额', + balanceRechargeMultiplier: '余额充值倍率', + balanceRechargeMultiplierHint: '用户每支付 1 CNY 可获得多少 USD 余额', + balanceRechargePreview: '预览:1 CNY = {usd} USD', + rechargeFeeRate: '充值手续费率', + rechargeFeeRateHint: '用户充值时额外收取的手续费百分比,0 表示不收取手续费', + rechargeFeePreview: '预览:充值 100 元,手续费 {fee} 元', orderTimeout: '订单超时时间', orderTimeoutHint: '单位:分钟,至少 1 分钟', maxPendingOrders: '最大待支付订单数', @@ -4714,6 +4821,27 @@ export default { supportedTypes: '支持的支付方式', supportedTypesHint: '逗号分隔,如 alipay,wxpay', refundEnabled: '允许退款', + allowUserRefund: '允许用户退款', + }, + balanceNotify: { + title: '余额不足提醒', + description: '当用户余额低于阈值时发送邮件提醒', + enabled: '启用余额不足提醒', + threshold: '默认提醒阈值', + thresholdHint: '用户未自定义时使用此值', + thresholdPlaceholder: '输入金额', + rechargeUrl: '充值页面 URL', + rechargeUrlPlaceholder: 'https://example.com/payment', + rechargeUrlHint: '设置后邮件中将包含充值链接按钮', + }, + quotaNotify: { + title: '账号限额通知', + description: '当账号配额用量达到告警阈值时通知管理员', + enabled: '启用账号限额通知', + emails: '通知邮箱', + emailsHint: '留空则不发送通知', + addEmail: '添加邮箱', + emailPlaceholder: '输入邮箱地址', }, smtp: { title: 'SMTP 设置', @@ -5392,6 +5520,8 @@ export default { payment: { title: '充值/订阅', amountLabel: '充值金额', + paymentAmount: '支付金额', + creditedBalance: '到账余额', quickAmounts: '快捷金额', customAmount: '自定义金额', enterAmount: '输入金额', @@ -5447,6 +5577,10 @@ export default { orderNo: '订单编号', amount: '金额', payAmount: '实付', + creditedAmount: '到账金额', + fee: '手续费', + baseAmount: '充值金额', + includedInPayAmount: '已含在实付金额中', status: '状态', paymentMethod: '支付方式', createdAt: '创建时间', @@ -5476,6 +5610,7 @@ export default { amountTooLow: '最低金额为 {min}', amountTooHigh: '最高金额为 {max}', amountNoMethod: '该金额没有可用的支付方式', + rechargeRatePreview: '当前倍率:1 CNY = {usd} USD', refundReason: '退款原因', refundReasonPlaceholder: '请描述您的退款原因', stripeLoadFailed: '支付组件加载失败,请刷新页面重试', @@ -5627,6 +5762,9 @@ export default { tabPlanConfig: '套餐配置', tabUserSubs: '用户订阅', selectGroup: '请选择分组', + groupRequired: '请选择订阅分组', + priceRequired: '价格必须大于 0', + validityDaysRequired: '有效期天数必须大于 0', groupMissing: '缺失', groupInfo: '分组信息', platform: '平台', diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index 09e73621..1995383d 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -339,7 +339,10 @@ export const useAppStore = defineStore('app', () => { oidc_oauth_enabled: false, oidc_oauth_provider_name: 'OIDC', backend_mode_enabled: false, - version: siteVersion.value + version: siteVersion.value, + balance_low_notify_enabled: false, + account_quota_notify_enabled: false, + balance_low_notify_threshold: 0, } } diff --git a/frontend/src/stores/payment.ts b/frontend/src/stores/payment.ts index ce2d7b3a..fc21c1f9 100644 --- a/frontend/src/stores/payment.ts +++ b/frontend/src/stores/payment.ts @@ -66,7 +66,7 @@ export const usePaymentStore = defineStore('payment', () => { return response.data } - /** Poll order status by ID */ + /** Poll order status by ID (read-only, no upstream check) */ async function pollOrderStatus(orderId: number): Promise { try { const response = await paymentAPI.getOrder(orderId) diff --git a/frontend/src/style.css b/frontend/src/style.css index 59c6d182..acff4abc 100644 --- a/frontend/src/style.css +++ b/frontend/src/style.css @@ -529,7 +529,6 @@ .sidebar-header { @apply h-16 px-6; @apply flex items-center gap-3; - @apply overflow-hidden; @apply border-b border-gray-100 dark:border-dark-800; transition: padding 0.2s ease, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 8b7e44c1..89fd777f 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -22,6 +22,16 @@ export interface FetchOptions { signal?: AbortSignal } +// ==================== Notification Types ==================== + +/** Notification email entry with enable/disable and verification state. + * email="" is a placeholder for the primary email (user's registration email or admin email). */ +export interface NotifyEmailEntry { + email: string + disabled: boolean + verified: boolean +} + // ==================== User & Auth Types ==================== export interface User { @@ -33,6 +43,9 @@ export interface User { concurrency: number // Allowed concurrent requests status: 'active' | 'disabled' // Account status allowed_groups: number[] | null // Allowed group IDs (null = all non-exclusive groups) + balance_notify_enabled: boolean + balance_notify_threshold: number | null + balance_notify_extra_emails: NotifyEmailEntry[] subscriptions?: UserSubscription[] // User's active subscriptions created_at: string updated_at: string @@ -114,6 +127,9 @@ export interface PublicSettings { oidc_oauth_provider_name: string backend_mode_enabled: boolean version: string + balance_low_notify_enabled: boolean + account_quota_notify_enabled: boolean + balance_low_notify_threshold: number } export interface AuthResponse { @@ -413,8 +429,6 @@ export interface AdminGroup extends Group { // MCP XML 协议注入(仅 antigravity 平台使用) mcp_xml_inject: boolean - // Claude usage 模拟开关(仅 anthropic 平台使用) - simulate_claude_max_enabled: boolean // 支持的模型系列(仅 antigravity 平台使用) supported_model_scopes?: string[] @@ -507,7 +521,6 @@ export interface CreateGroupRequest { fallback_group_id?: number | null fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean - simulate_claude_max_enabled?: boolean supported_model_scopes?: string[] require_oauth_only?: boolean require_privacy_set?: boolean @@ -533,7 +546,6 @@ export interface UpdateGroupRequest { fallback_group_id?: number | null fallback_group_id_on_invalid_request?: number | null mcp_xml_inject?: boolean - simulate_claude_max_enabled?: boolean supported_model_scopes?: string[] require_oauth_only?: boolean require_privacy_set?: boolean @@ -675,6 +687,7 @@ export interface Account { // Extra fields including Codex usage and model-level rate limits (Antigravity smart retry) extra?: (CodexUsageSnapshot & { model_rate_limits?: Record + antigravity_credits_overages?: Record } & Record) proxy_id: number | null concurrency: number @@ -736,12 +749,6 @@ export interface Account { custom_base_url_enabled?: boolean | null custom_base_url?: string | null - // 客户端亲和调度(仅 Anthropic/Antigravity 平台有效) - // 启用后新会话会优先调度到客户端之前使用过的账号 - client_affinity_enabled?: boolean | null - affinity_client_count?: number | null - affinity_clients?: string[] | null - // API Key 账号配额限制 quota_limit?: number | null quota_used?: number | null @@ -1050,6 +1057,12 @@ export interface AdminUsageLog extends UsageLog { // 账号计费倍率(仅管理员可见) account_rate_multiplier?: number | null + // 自定义定价规则计算的账号统计费用(nil 时使用 total_cost * multiplier) + account_stats_cost?: number | null + + // 渠道 ID 和计费等级(仅管理员可见) + channel_id?: number | null + billing_tier?: string | null // 用户请求 IP(仅管理员可见) ip_address?: string | null @@ -1145,6 +1158,7 @@ export interface DashboardStats { total_tokens: number total_cost: number // 累计标准计费 total_actual_cost: number // 累计实际扣除 + total_account_cost: number // 累计账号成本 // 今日 Token 使用统计 today_requests: number @@ -1155,6 +1169,7 @@ export interface DashboardStats { today_tokens: number today_cost: number // 今日标准计费 today_actual_cost: number // 今日实际扣除 + today_account_cost: number // 今日账号成本 // 系统运行统计 average_duration_ms: number // 平均响应时间 @@ -1202,6 +1217,7 @@ export interface ModelStat { total_tokens: number cost: number // 标准计费 actual_cost: number // 实际扣除 + account_cost: number // 账号成本 } export interface EndpointStat { @@ -1219,6 +1235,7 @@ export interface GroupStat { total_tokens: number cost: number // 标准计费 actual_cost: number // 实际扣除 + account_cost: number // 账号成本 } export interface UserBreakdownItem { @@ -1228,6 +1245,7 @@ export interface UserBreakdownItem { total_tokens: number cost: number actual_cost: number + account_cost: number } export interface UserUsageTrendPoint { diff --git a/frontend/src/types/payment.ts b/frontend/src/types/payment.ts index ac209ad7..7ecbb9a9 100644 --- a/frontend/src/types/payment.ts +++ b/frontend/src/types/payment.ts @@ -32,6 +32,7 @@ export interface PaymentConfig { max_pending_orders: number order_timeout_minutes: number balance_disabled: boolean + balance_recharge_multiplier: number enabled_payment_types: PaymentType[] help_image_url: string help_text: string @@ -62,6 +63,8 @@ export interface CheckoutInfoResponse { global_max: number plans: SubscriptionPlan[] balance_disabled: boolean + balance_recharge_multiplier: number + recharge_fee_rate: number help_text: string help_image_url: string stripe_publishable_key: string @@ -89,6 +92,7 @@ export interface PaymentOrder { refund_requested_by?: number refund_request_reason?: string plan_id?: number + provider_instance_id?: string } // ==================== Plans & Channels ==================== @@ -138,6 +142,7 @@ export interface ProviderInstance { enabled: boolean payment_mode: string refund_enabled: boolean + allow_user_refund: boolean limits: string sort_order: number } @@ -153,10 +158,12 @@ export interface CreateOrderRequest { export interface CreateOrderResult { order_id: number + amount: number pay_url?: string qr_code?: string client_secret?: string pay_amount: number + fee_rate: number expires_at: string payment_mode?: string } diff --git a/frontend/src/utils/__tests__/usageLoadQueue.spec.ts b/frontend/src/utils/__tests__/usageLoadQueue.spec.ts new file mode 100644 index 00000000..e5509261 --- /dev/null +++ b/frontend/src/utils/__tests__/usageLoadQueue.spec.ts @@ -0,0 +1,205 @@ +import { describe, expect, it } from 'vitest' +import { enqueueUsageRequest } from '../usageLoadQueue' +import type { Account } from '@/types' + +/** Helper to create a minimal Account with proxy info */ +function makeAccount( + platform: string, + type: string = 'oauth', + proxy?: { host: string; port: number; username?: string | null } | null +): Account { + return { + id: Math.floor(Math.random() * 10000), + platform, + type, + name: 'test', + status: 'active', + proxy_id: proxy ? 1 : null, + proxy: proxy + ? { id: 1, name: 'p', protocol: 'http', host: proxy.host, port: proxy.port, username: proxy.username ?? null, status: 'active', created_at: '', updated_at: '' } + : undefined, + credentials: {}, + created_at: '', + updated_at: '' + } as unknown as Account +} + +describe('usageLoadQueue', () => { + // ─── Anthropic 账号:按代理出口排队 ─── + + it('Anthropic 同代理出口串行执行,间隔 >= 1s', async () => { + const timestamps: number[] = [] + const makeFn = () => async () => { + timestamps.push(Date.now()) + return 'ok' + } + + const acc = makeAccount('anthropic', 'oauth', { host: '1.2.3.4', port: 8080, username: 'u1' }) + + const p1 = enqueueUsageRequest(acc, makeFn()) + const p2 = enqueueUsageRequest(acc, makeFn()) + const p3 = enqueueUsageRequest(acc, makeFn()) + + await Promise.all([p1, p2, p3]) + + expect(timestamps).toHaveLength(3) + expect(timestamps[1] - timestamps[0]).toBeGreaterThanOrEqual(950) + expect(timestamps[1] - timestamps[0]).toBeLessThan(2100) + expect(timestamps[2] - timestamps[1]).toBeGreaterThanOrEqual(950) + expect(timestamps[2] - timestamps[1]).toBeLessThan(2100) + }) + + it('Anthropic 不同代理出口并行执行', async () => { + const timestamps: Record = {} + const makeTracked = (key: string) => async () => { + timestamps[key] = Date.now() + return key + } + + const acc1 = makeAccount('anthropic', 'oauth', { host: '1.2.3.4', port: 8080, username: 'u1' }) + const acc2 = makeAccount('anthropic', 'oauth', { host: '5.6.7.8', port: 3128, username: 'u2' }) + + const p1 = enqueueUsageRequest(acc1, makeTracked('proxy1')) + const p2 = enqueueUsageRequest(acc2, makeTracked('proxy2')) + + await Promise.all([p1, p2]) + + const spread = Math.abs(timestamps['proxy1'] - timestamps['proxy2']) + expect(spread).toBeLessThan(50) + }) + + it('Anthropic 相同代理连接信息的不同账号归为同一队列', async () => { + const timestamps: number[] = [] + const makeFn = () => async () => { + timestamps.push(Date.now()) + return 'ok' + } + + const acc1 = makeAccount('anthropic', 'oauth', { host: '10.0.0.1', port: 3128, username: 'admin' }) + const acc2 = makeAccount('anthropic', 'setup-token', { host: '10.0.0.1', port: 3128, username: 'admin' }) + + const p1 = enqueueUsageRequest(acc1, makeFn()) + const p2 = enqueueUsageRequest(acc2, makeFn()) + + await Promise.all([p1, p2]) + + expect(timestamps).toHaveLength(2) + expect(timestamps[1] - timestamps[0]).toBeGreaterThanOrEqual(950) + }) + + it('Anthropic 直连(无代理)的账号归为同一队列', async () => { + const order: number[] = [] + const makeFn = (n: number) => async () => { + order.push(n) + return n + } + + const acc1 = makeAccount('anthropic', 'oauth') + const acc2 = makeAccount('anthropic', 'setup-token') + + const p1 = enqueueUsageRequest(acc1, makeFn(1)) + const p2 = enqueueUsageRequest(acc2, makeFn(2)) + + await Promise.all([p1, p2]) + + expect(order).toEqual([1, 2]) + }) + + it('Anthropic 请求失败时 reject,后续任务继续执行', async () => { + const results: string[] = [] + const acc = makeAccount('anthropic', 'oauth', { host: '99.99.99.99', port: 1234 }) + + const p1 = enqueueUsageRequest(acc, async () => { + throw new Error('fail') + }) + const p2 = enqueueUsageRequest(acc, async () => { + results.push('second') + return 'ok' + }) + + await expect(p1).rejects.toThrow('fail') + await p2 + expect(results).toEqual(['second']) + }) + + // ─── 非 Anthropic 平台:直接执行,不排队 ─── + + it('非 Anthropic 平台直接执行,不排队', async () => { + const timestamps: number[] = [] + const makeFn = () => async () => { + timestamps.push(Date.now()) + return 'ok' + } + + // 同一代理的 Gemini 账号 — 应当并行,不排队 + const acc1 = makeAccount('gemini', 'oauth', { host: '1.2.3.4', port: 8080 }) + const acc2 = makeAccount('gemini', 'oauth', { host: '1.2.3.4', port: 8080 }) + + const p1 = enqueueUsageRequest(acc1, makeFn()) + const p2 = enqueueUsageRequest(acc2, makeFn()) + + await Promise.all([p1, p2]) + + expect(timestamps).toHaveLength(2) + // 并行执行,几乎同时完成 + expect(Math.abs(timestamps[1] - timestamps[0])).toBeLessThan(50) + }) + + it('OpenAI 平台直接执行,不排队', async () => { + const timestamps: number[] = [] + const makeFn = () => async () => { + timestamps.push(Date.now()) + return 'ok' + } + + const acc1 = makeAccount('openai', 'oauth', { host: '1.2.3.4', port: 8080 }) + const acc2 = makeAccount('openai', 'oauth', { host: '1.2.3.4', port: 8080 }) + + const p1 = enqueueUsageRequest(acc1, makeFn()) + const p2 = enqueueUsageRequest(acc2, makeFn()) + + await Promise.all([p1, p2]) + + expect(timestamps).toHaveLength(2) + expect(Math.abs(timestamps[1] - timestamps[0])).toBeLessThan(50) + }) + + // ─── Anthropic apikey 类型不排队 ─── + + it('Anthropic apikey 类型直接执行,不排队', async () => { + const timestamps: number[] = [] + const makeFn = () => async () => { + timestamps.push(Date.now()) + return 'ok' + } + + const acc1 = makeAccount('anthropic', 'apikey', { host: '1.2.3.4', port: 8080 }) + const acc2 = makeAccount('anthropic', 'apikey', { host: '1.2.3.4', port: 8080 }) + + const p1 = enqueueUsageRequest(acc1, makeFn()) + const p2 = enqueueUsageRequest(acc2, makeFn()) + + await Promise.all([p1, p2]) + + expect(timestamps).toHaveLength(2) + expect(Math.abs(timestamps[1] - timestamps[0])).toBeLessThan(50) + }) + + // ─── 返回值透传 ─── + + it('返回值正确透传', async () => { + const acc = makeAccount('anthropic', 'oauth') + const result = await enqueueUsageRequest(acc, async () => { + return { usage: 42 } + }) + expect(result).toEqual({ usage: 42 }) + }) + + it('非 Anthropic 返回值正确透传', async () => { + const acc = makeAccount('gemini', 'oauth') + const result = await enqueueUsageRequest(acc, async () => { + return { quota: 100 } + }) + expect(result).toEqual({ quota: 100 }) + }) +}) diff --git a/frontend/src/utils/billingMode.ts b/frontend/src/utils/billingMode.ts new file mode 100644 index 00000000..152dadc4 --- /dev/null +++ b/frontend/src/utils/billingMode.ts @@ -0,0 +1,19 @@ +export const BILLING_MODE_TOKEN = 'token' +export const BILLING_MODE_PER_REQUEST = 'per_request' +export const BILLING_MODE_IMAGE = 'image' + +export function getBillingModeLabel(mode: string | null | undefined, t: (key: string) => string): string { + switch (mode) { + case BILLING_MODE_PER_REQUEST: return t('admin.usage.billingModePerRequest') + case BILLING_MODE_IMAGE: return t('admin.usage.billingModeImage') + default: return t('admin.usage.billingModeToken') + } +} + +export function getBillingModeBadgeClass(mode: string | null | undefined): string { + switch (mode) { + case BILLING_MODE_PER_REQUEST: return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-300' + case BILLING_MODE_IMAGE: return 'bg-pink-100 text-pink-700 dark:bg-pink-900/30 dark:text-pink-300' + default: return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-300' + } +} diff --git a/frontend/src/utils/usageLoadQueue.ts b/frontend/src/utils/usageLoadQueue.ts new file mode 100644 index 00000000..7bea5679 --- /dev/null +++ b/frontend/src/utils/usageLoadQueue.ts @@ -0,0 +1,93 @@ +/** + * Usage request scheduler — throttles Anthropic API calls by proxy exit. + * + * Anthropic OAuth/setup-token accounts sharing the same proxy exit are placed + * into a serial queue with a random 1–2s delay between requests, preventing + * upstream 429 rate-limit errors. + * + * Proxy identity = host:port:username — two proxy records pointing to the + * same exit share a single queue. Accounts without a proxy go into a + * "direct" queue. + * + * All other platforms bypass the queue and execute immediately. + */ + +import type { Account } from '@/types' + +const GROUP_DELAY_MIN_MS = 1000 +const GROUP_DELAY_MAX_MS = 2000 + +type Task = { + fn: () => Promise + resolve: (value: T) => void + reject: (reason: unknown) => void +} + +const queues = new Map[]>() +const running = new Set() + +/** Whether this account needs throttled queuing. */ +function needsThrottle(account: Account): boolean { + return ( + account.platform === 'anthropic' && + (account.type === 'oauth' || account.type === 'setup-token') + ) +} + +/** Build a queue key from proxy connection details. */ +function buildGroupKey(account: Account): string { + const proxy = account.proxy + const proxyIdentity = proxy + ? `${proxy.host}:${proxy.port}:${proxy.username || ''}` + : 'direct' + return `anthropic:${proxyIdentity}` +} + +async function drain(groupKey: string) { + if (running.has(groupKey)) return + running.add(groupKey) + + const queue = queues.get(groupKey) + while (queue && queue.length > 0) { + const task = queue.shift()! + try { + const result = await task.fn() + task.resolve(result) + } catch (err) { + task.reject(err) + } + if (queue.length > 0) { + const jitter = GROUP_DELAY_MIN_MS + Math.random() * (GROUP_DELAY_MAX_MS - GROUP_DELAY_MIN_MS) + await new Promise((r) => setTimeout(r, jitter)) + } + } + + running.delete(groupKey) + queues.delete(groupKey) +} + +/** + * Schedule a usage fetch. Anthropic accounts are queued by proxy exit; + * all other platforms execute immediately. + */ +export function enqueueUsageRequest( + account: Account, + fn: () => Promise +): Promise { + // Non-Anthropic → fire immediately, no queuing + if (!needsThrottle(account)) { + return fn() + } + + const key = buildGroupKey(account) + + return new Promise((resolve, reject) => { + let queue = queues.get(key) + if (!queue) { + queue = [] + queues.set(key, queue) + } + queue.push({ fn, resolve, reject } as Task) + drain(key) + }) +} diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index d7fae112..4fec956b 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -144,6 +144,7 @@
diff --git a/frontend/src/views/user/PaymentQRCodeView.vue b/frontend/src/views/user/PaymentQRCodeView.vue index 50173a1a..0965947a 100644 --- a/frontend/src/views/user/PaymentQRCodeView.vue +++ b/frontend/src/views/user/PaymentQRCodeView.vue @@ -94,12 +94,12 @@ async function renderQR() { await nextTick() if (!qrCanvas.value || !qrUrl.value) return - // Use high error correction to support logo overlay + // Use medium error correction to support logo overlay while keeping QR code scannable const logoSrc = getLogoForType() await QRCode.toCanvas(qrCanvas.value, qrUrl.value, { width: 256, margin: 2, - errorCorrectionLevel: logoSrc ? 'H' : 'M', + errorCorrectionLevel: logoSrc ? 'M' : 'L', }) if (!logoSrc) return diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue index bc16918c..6431ddf6 100644 --- a/frontend/src/views/user/PaymentResultView.vue +++ b/frontend/src/views/user/PaymentResultView.vue @@ -37,8 +37,20 @@ {{ order.out_trade_no }}
- {{ t('payment.orders.amount') }} - ¥{{ order.pay_amount.toFixed(2) }} + {{ t('payment.orders.baseAmount') }} + ¥{{ baseAmount.toFixed(2) }} +
+
+ {{ t('payment.orders.fee') }} ({{ order.fee_rate }}%) + ¥{{ feeAmount.toFixed(2) }} +
+
+ {{ t('payment.orders.payAmount') }} + ¥{{ order.pay_amount.toFixed(2) }} +
+
+ {{ t('payment.orders.creditedAmount') }} + {{ order.order_type === 'balance' ? '$' : '¥' }}{{ order.amount.toFixed(2) }}
{{ t('payment.orders.paymentMethod') }} @@ -58,7 +70,7 @@ {{ returnInfo.outTradeNo }}
- {{ t('payment.orders.amount') }} + {{ t('payment.orders.payAmount') }} ¥{{ returnInfo.money }}
@@ -104,6 +116,18 @@ const returnInfo = ref(null) const SUCCESS_STATUSES = new Set(['COMPLETED', 'PAID', 'RECHARGING']) +/** 充值金额 = pay_amount / (1 + fee_rate/100),fee_rate=0 时等于 pay_amount */ +const baseAmount = computed(() => { + if (!order.value || order.value.fee_rate <= 0) return order.value?.pay_amount ?? 0 + return Math.round((order.value.pay_amount / (1 + order.value.fee_rate / 100)) * 100) / 100 +}) + +/** 手续费 = pay_amount - baseAmount */ +const feeAmount = computed(() => { + if (!order.value || order.value.fee_rate <= 0) return 0 + return Math.round((order.value.pay_amount - baseAmount.value) * 100) / 100 +}) + const isSuccess = computed(() => { // Always prioritize actual order status from backend if (order.value) { diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 5a958097..e91df5da 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -28,7 +28,9 @@