chore: merge upstream v0.1.113, keep Antigravity customizations
This commit is contained in:
commit
b6e1c64c25
2
.github/workflows/backend-ci.yml
vendored
2
.github/workflows/backend-ci.yml
vendored
@ -17,6 +17,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.2'
|
||||||
@ -36,6 +37,7 @@ jobs:
|
|||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
check-latest: false
|
check-latest: false
|
||||||
cache: true
|
cache: true
|
||||||
|
cache-dependency-path: backend/go.sum
|
||||||
- name: Verify Go version
|
- name: Verify Go version
|
||||||
run: |
|
run: |
|
||||||
go version | grep -q 'go1.26.2'
|
go version | grep -q 'go1.26.2'
|
||||||
|
|||||||
10
README.md
10
README.md
@ -86,6 +86,16 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
|||||||
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via <a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
|
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via <a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||||
|
<td>Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via <a href="https://aigocode.com/invite/SUB2API">this link</a>, you'll receive an extra 10% bonus credit on your first top-up!</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||||
|
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
## Ecosystem
|
## Ecosystem
|
||||||
|
|||||||
10
README_CN.md
10
README_CN.md
@ -85,6 +85,16 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
|||||||
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">此链接</a>注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!</td>
|
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">此链接</a>注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||||
|
<td>感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过<a href="https://aigocode.com/invite/SUB2API">此链接</a>注册,首次充值可额外获得 10% 赠送额度!</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||||
|
<td>感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI成品号专卖/代充</a>注册下单的用户,可享GPT 官网订阅一折 的震撼价格!</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
## 生态项目
|
## 生态项目
|
||||||
|
|||||||
10
README_JA.md
10
README_JA.md
@ -85,6 +85,16 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
|||||||
<td>AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">こちらのリンク</a>から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!</td>
|
<td>AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">こちらのリンク</a>から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!</td>
|
||||||
</tr>
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||||
|
<td>AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:<a href="https://aigocode.com/invite/SUB2API">こちらのリンク</a>から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
|
<tr>
|
||||||
|
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||||
|
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!</td>
|
||||||
|
</tr>
|
||||||
|
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
## エコシステム
|
## エコシステム
|
||||||
|
|||||||
BIN
assets/partners/logos/aigocode.png
Normal file
BIN
assets/partners/logos/aigocode.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
BIN
assets/partners/logos/bmoplus.jpg
Normal file
BIN
assets/partners/logos/bmoplus.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.8 KiB |
@ -1 +1 @@
|
|||||||
0.1.111
|
0.1.113
|
||||||
|
|||||||
@ -36,15 +36,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
// Business layer ProviderSets
|
// Business layer ProviderSets
|
||||||
repository.ProviderSet,
|
repository.ProviderSet,
|
||||||
service.ProviderSet,
|
service.ProviderSet,
|
||||||
|
payment.ProviderSet,
|
||||||
middleware.ProviderSet,
|
middleware.ProviderSet,
|
||||||
handler.ProviderSet,
|
handler.ProviderSet,
|
||||||
|
|
||||||
// Server layer ProviderSet
|
// Server layer ProviderSet
|
||||||
server.ProviderSet,
|
server.ProviderSet,
|
||||||
|
|
||||||
// Payment providers
|
|
||||||
payment.ProviderSet,
|
|
||||||
|
|
||||||
// Privacy client factory for OpenAI training opt-out
|
// Privacy client factory for OpenAI training opt-out
|
||||||
providePrivacyClientFactory,
|
providePrivacyClientFactory,
|
||||||
|
|
||||||
|
|||||||
@ -50,7 +50,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
|
||||||
settingRepository := repository.NewSettingRepository(client)
|
settingRepository := repository.NewSettingRepository(client)
|
||||||
groupRepository := repository.NewGroupRepository(client, db)
|
groupRepository := repository.NewGroupRepository(client, db)
|
||||||
settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
|
proxyRepository := repository.NewProxyRepository(client, db)
|
||||||
|
settingService := service.ProvideSettingService(settingRepository, groupRepository, proxyRepository, configConfig)
|
||||||
emailCache := repository.NewEmailCache(redisClient)
|
emailCache := repository.NewEmailCache(redisClient)
|
||||||
emailService := service.NewEmailService(settingRepository, emailCache)
|
emailService := service.NewEmailService(settingRepository, emailCache)
|
||||||
turnstileVerifier := repository.NewTurnstileVerifier()
|
turnstileVerifier := repository.NewTurnstileVerifier()
|
||||||
@ -68,7 +69,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
|
||||||
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
|
||||||
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
|
userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache)
|
||||||
redeemCache := repository.NewRedeemCache(redisClient)
|
redeemCache := repository.NewRedeemCache(redisClient)
|
||||||
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
|
||||||
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
|
||||||
@ -78,7 +79,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
totpCache := repository.NewTotpCache(redisClient)
|
totpCache := repository.NewTotpCache(redisClient)
|
||||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||||
userHandler := handler.NewUserHandler(userService)
|
userHandler := handler.NewUserHandler(userService, emailService, emailCache)
|
||||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
|
||||||
@ -100,7 +101,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
|
||||||
schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
|
schedulerCache := repository.ProvideSchedulerCache(redisClient, configConfig)
|
||||||
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
accountRepository := repository.NewAccountRepository(client, db, schedulerCache)
|
||||||
proxyRepository := repository.NewProxyRepository(client, db)
|
|
||||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||||
privacyClientFactory := providePrivacyClientFactory()
|
privacyClientFactory := providePrivacyClientFactory()
|
||||||
@ -136,7 +136,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||||
@ -176,21 +176,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
channelRepository := repository.NewChannelRepository(db)
|
channelRepository := repository.NewChannelRepository(db)
|
||||||
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
|
channelService := service.NewChannelService(channelRepository, apiKeyAuthCacheInvalidator)
|
||||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver)
|
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository)
|
||||||
|
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService)
|
||||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService)
|
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService)
|
||||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||||
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
|
||||||
registry := payment.ProvideRegistry()
|
|
||||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
|
||||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
|
||||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
|
||||||
opsHandler := admin.NewOpsHandler(opsService)
|
opsHandler := admin.NewOpsHandler(opsService)
|
||||||
updateCache := repository.NewUpdateCache(redisClient)
|
updateCache := repository.NewUpdateCache(redisClient)
|
||||||
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
|
||||||
@ -218,6 +210,16 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||||
|
registry := payment.ProvideRegistry()
|
||||||
|
encryptionKey, err := payment.ProvideEncryptionKey(configConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||||
|
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
|
||||||
|
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
|
||||||
|
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||||
|
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||||
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
|
||||||
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
|
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)
|
||||||
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
|
||||||
@ -235,8 +237,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService)
|
||||||
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService)
|
||||||
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig)
|
||||||
languageServerService := service.ProvideLanguageServerService(httpUpstream)
|
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient)
|
||||||
engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, redisClient, languageServerService)
|
|
||||||
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
httpServer := server.ProvideHTTPServer(configConfig, engine)
|
||||||
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig)
|
||||||
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig)
|
||||||
@ -247,7 +248,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
|||||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository)
|
||||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
|
||||||
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
|
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService)
|
||||||
application := &Application{
|
application := &Application{
|
||||||
Server: httpServer,
|
Server: httpServer,
|
||||||
|
|||||||
@ -616,6 +616,7 @@ var (
|
|||||||
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
{Name: "sort_order", Type: field.TypeInt, Default: 0},
|
||||||
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "limits", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "refund_enabled", Type: field.TypeBool, Default: false},
|
{Name: "refund_enabled", Type: field.TypeBool, Default: false},
|
||||||
|
{Name: "allow_user_refund", Type: field.TypeBool, Default: false},
|
||||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||||
}
|
}
|
||||||
@ -1078,6 +1079,11 @@ var (
|
|||||||
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
{Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||||
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
{Name: "totp_enabled", Type: field.TypeBool, Default: false},
|
||||||
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
{Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true},
|
||||||
|
{Name: "balance_notify_enabled", Type: field.TypeBool, Default: true},
|
||||||
|
{Name: "balance_notify_threshold_type", Type: field.TypeString, Default: "fixed"},
|
||||||
|
{Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
|
{Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}},
|
||||||
|
{Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
|
||||||
}
|
}
|
||||||
// UsersTable holds the schema information for the "users" table.
|
// UsersTable holds the schema information for the "users" table.
|
||||||
UsersTable = &schema.Table{
|
UsersTable = &schema.Table{
|
||||||
|
|||||||
@ -15642,25 +15642,26 @@ func (m *PaymentOrderMutation) ResetEdge(name string) error {
|
|||||||
// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
|
// PaymentProviderInstanceMutation represents an operation that mutates the PaymentProviderInstance nodes in the graph.
|
||||||
type PaymentProviderInstanceMutation struct {
|
type PaymentProviderInstanceMutation struct {
|
||||||
config
|
config
|
||||||
op Op
|
op Op
|
||||||
typ string
|
typ string
|
||||||
id *int64
|
id *int64
|
||||||
provider_key *string
|
provider_key *string
|
||||||
name *string
|
name *string
|
||||||
_config *string
|
_config *string
|
||||||
supported_types *string
|
supported_types *string
|
||||||
enabled *bool
|
enabled *bool
|
||||||
payment_mode *string
|
payment_mode *string
|
||||||
sort_order *int
|
sort_order *int
|
||||||
addsort_order *int
|
addsort_order *int
|
||||||
limits *string
|
limits *string
|
||||||
refund_enabled *bool
|
refund_enabled *bool
|
||||||
created_at *time.Time
|
allow_user_refund *bool
|
||||||
updated_at *time.Time
|
created_at *time.Time
|
||||||
clearedFields map[string]struct{}
|
updated_at *time.Time
|
||||||
done bool
|
clearedFields map[string]struct{}
|
||||||
oldValue func(context.Context) (*PaymentProviderInstance, error)
|
done bool
|
||||||
predicates []predicate.PaymentProviderInstance
|
oldValue func(context.Context) (*PaymentProviderInstance, error)
|
||||||
|
predicates []predicate.PaymentProviderInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
|
var _ ent.Mutation = (*PaymentProviderInstanceMutation)(nil)
|
||||||
@ -16105,6 +16106,42 @@ func (m *PaymentProviderInstanceMutation) ResetRefundEnabled() {
|
|||||||
m.refund_enabled = nil
|
m.refund_enabled = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (m *PaymentProviderInstanceMutation) SetAllowUserRefund(b bool) {
|
||||||
|
m.allow_user_refund = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowUserRefund returns the value of the "allow_user_refund" field in the mutation.
|
||||||
|
func (m *PaymentProviderInstanceMutation) AllowUserRefund() (r bool, exists bool) {
|
||||||
|
v := m.allow_user_refund
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldAllowUserRefund returns the old "allow_user_refund" field's value of the PaymentProviderInstance entity.
|
||||||
|
// If the PaymentProviderInstance object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *PaymentProviderInstanceMutation) OldAllowUserRefund(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldAllowUserRefund is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldAllowUserRefund requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldAllowUserRefund: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.AllowUserRefund, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetAllowUserRefund resets all changes to the "allow_user_refund" field.
|
||||||
|
func (m *PaymentProviderInstanceMutation) ResetAllowUserRefund() {
|
||||||
|
m.allow_user_refund = nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetCreatedAt sets the "created_at" field.
|
// SetCreatedAt sets the "created_at" field.
|
||||||
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
|
func (m *PaymentProviderInstanceMutation) SetCreatedAt(t time.Time) {
|
||||||
m.created_at = &t
|
m.created_at = &t
|
||||||
@ -16211,7 +16248,7 @@ func (m *PaymentProviderInstanceMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *PaymentProviderInstanceMutation) Fields() []string {
|
func (m *PaymentProviderInstanceMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 11)
|
fields := make([]string, 0, 12)
|
||||||
if m.provider_key != nil {
|
if m.provider_key != nil {
|
||||||
fields = append(fields, paymentproviderinstance.FieldProviderKey)
|
fields = append(fields, paymentproviderinstance.FieldProviderKey)
|
||||||
}
|
}
|
||||||
@ -16239,6 +16276,9 @@ func (m *PaymentProviderInstanceMutation) Fields() []string {
|
|||||||
if m.refund_enabled != nil {
|
if m.refund_enabled != nil {
|
||||||
fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
|
fields = append(fields, paymentproviderinstance.FieldRefundEnabled)
|
||||||
}
|
}
|
||||||
|
if m.allow_user_refund != nil {
|
||||||
|
fields = append(fields, paymentproviderinstance.FieldAllowUserRefund)
|
||||||
|
}
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, paymentproviderinstance.FieldCreatedAt)
|
fields = append(fields, paymentproviderinstance.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@ -16271,6 +16311,8 @@ func (m *PaymentProviderInstanceMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.Limits()
|
return m.Limits()
|
||||||
case paymentproviderinstance.FieldRefundEnabled:
|
case paymentproviderinstance.FieldRefundEnabled:
|
||||||
return m.RefundEnabled()
|
return m.RefundEnabled()
|
||||||
|
case paymentproviderinstance.FieldAllowUserRefund:
|
||||||
|
return m.AllowUserRefund()
|
||||||
case paymentproviderinstance.FieldCreatedAt:
|
case paymentproviderinstance.FieldCreatedAt:
|
||||||
return m.CreatedAt()
|
return m.CreatedAt()
|
||||||
case paymentproviderinstance.FieldUpdatedAt:
|
case paymentproviderinstance.FieldUpdatedAt:
|
||||||
@ -16302,6 +16344,8 @@ func (m *PaymentProviderInstanceMutation) OldField(ctx context.Context, name str
|
|||||||
return m.OldLimits(ctx)
|
return m.OldLimits(ctx)
|
||||||
case paymentproviderinstance.FieldRefundEnabled:
|
case paymentproviderinstance.FieldRefundEnabled:
|
||||||
return m.OldRefundEnabled(ctx)
|
return m.OldRefundEnabled(ctx)
|
||||||
|
case paymentproviderinstance.FieldAllowUserRefund:
|
||||||
|
return m.OldAllowUserRefund(ctx)
|
||||||
case paymentproviderinstance.FieldCreatedAt:
|
case paymentproviderinstance.FieldCreatedAt:
|
||||||
return m.OldCreatedAt(ctx)
|
return m.OldCreatedAt(ctx)
|
||||||
case paymentproviderinstance.FieldUpdatedAt:
|
case paymentproviderinstance.FieldUpdatedAt:
|
||||||
@ -16378,6 +16422,13 @@ func (m *PaymentProviderInstanceMutation) SetField(name string, value ent.Value)
|
|||||||
}
|
}
|
||||||
m.SetRefundEnabled(v)
|
m.SetRefundEnabled(v)
|
||||||
return nil
|
return nil
|
||||||
|
case paymentproviderinstance.FieldAllowUserRefund:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetAllowUserRefund(v)
|
||||||
|
return nil
|
||||||
case paymentproviderinstance.FieldCreatedAt:
|
case paymentproviderinstance.FieldCreatedAt:
|
||||||
v, ok := value.(time.Time)
|
v, ok := value.(time.Time)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -16483,6 +16534,9 @@ func (m *PaymentProviderInstanceMutation) ResetField(name string) error {
|
|||||||
case paymentproviderinstance.FieldRefundEnabled:
|
case paymentproviderinstance.FieldRefundEnabled:
|
||||||
m.ResetRefundEnabled()
|
m.ResetRefundEnabled()
|
||||||
return nil
|
return nil
|
||||||
|
case paymentproviderinstance.FieldAllowUserRefund:
|
||||||
|
m.ResetAllowUserRefund()
|
||||||
|
return nil
|
||||||
case paymentproviderinstance.FieldCreatedAt:
|
case paymentproviderinstance.FieldCreatedAt:
|
||||||
m.ResetCreatedAt()
|
m.ResetCreatedAt()
|
||||||
return nil
|
return nil
|
||||||
@ -28210,6 +28264,13 @@ type UserMutation struct {
|
|||||||
totp_secret_encrypted *string
|
totp_secret_encrypted *string
|
||||||
totp_enabled *bool
|
totp_enabled *bool
|
||||||
totp_enabled_at *time.Time
|
totp_enabled_at *time.Time
|
||||||
|
balance_notify_enabled *bool
|
||||||
|
balance_notify_threshold_type *string
|
||||||
|
balance_notify_threshold *float64
|
||||||
|
addbalance_notify_threshold *float64
|
||||||
|
balance_notify_extra_emails *string
|
||||||
|
total_recharged *float64
|
||||||
|
addtotal_recharged *float64
|
||||||
clearedFields map[string]struct{}
|
clearedFields map[string]struct{}
|
||||||
api_keys map[int64]struct{}
|
api_keys map[int64]struct{}
|
||||||
removedapi_keys map[int64]struct{}
|
removedapi_keys map[int64]struct{}
|
||||||
@ -28927,6 +28988,240 @@ func (m *UserMutation) ResetTotpEnabledAt() {
|
|||||||
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
delete(m.clearedFields, user.FieldTotpEnabledAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (m *UserMutation) SetBalanceNotifyEnabled(b bool) {
|
||||||
|
m.balance_notify_enabled = &b
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyEnabled returns the value of the "balance_notify_enabled" field in the mutation.
|
||||||
|
func (m *UserMutation) BalanceNotifyEnabled() (r bool, exists bool) {
|
||||||
|
v := m.balance_notify_enabled
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldBalanceNotifyEnabled returns the old "balance_notify_enabled" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldBalanceNotifyEnabled(ctx context.Context) (v bool, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldBalanceNotifyEnabled is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldBalanceNotifyEnabled requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldBalanceNotifyEnabled: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.BalanceNotifyEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetBalanceNotifyEnabled resets all changes to the "balance_notify_enabled" field.
|
||||||
|
func (m *UserMutation) ResetBalanceNotifyEnabled() {
|
||||||
|
m.balance_notify_enabled = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (m *UserMutation) SetBalanceNotifyThresholdType(s string) {
|
||||||
|
m.balance_notify_threshold_type = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdType returns the value of the "balance_notify_threshold_type" field in the mutation.
|
||||||
|
func (m *UserMutation) BalanceNotifyThresholdType() (r string, exists bool) {
|
||||||
|
v := m.balance_notify_threshold_type
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldBalanceNotifyThresholdType returns the old "balance_notify_threshold_type" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldBalanceNotifyThresholdType(ctx context.Context) (v string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldBalanceNotifyThresholdType is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldBalanceNotifyThresholdType requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldBalanceNotifyThresholdType: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.BalanceNotifyThresholdType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetBalanceNotifyThresholdType resets all changes to the "balance_notify_threshold_type" field.
|
||||||
|
func (m *UserMutation) ResetBalanceNotifyThresholdType() {
|
||||||
|
m.balance_notify_threshold_type = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (m *UserMutation) SetBalanceNotifyThreshold(f float64) {
|
||||||
|
m.balance_notify_threshold = &f
|
||||||
|
m.addbalance_notify_threshold = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThreshold returns the value of the "balance_notify_threshold" field in the mutation.
|
||||||
|
func (m *UserMutation) BalanceNotifyThreshold() (r float64, exists bool) {
|
||||||
|
v := m.balance_notify_threshold
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldBalanceNotifyThreshold returns the old "balance_notify_threshold" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldBalanceNotifyThreshold(ctx context.Context) (v *float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldBalanceNotifyThreshold is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldBalanceNotifyThreshold requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldBalanceNotifyThreshold: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.BalanceNotifyThreshold, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds f to the "balance_notify_threshold" field.
|
||||||
|
func (m *UserMutation) AddBalanceNotifyThreshold(f float64) {
|
||||||
|
if m.addbalance_notify_threshold != nil {
|
||||||
|
*m.addbalance_notify_threshold += f
|
||||||
|
} else {
|
||||||
|
m.addbalance_notify_threshold = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedBalanceNotifyThreshold returns the value that was added to the "balance_notify_threshold" field in this mutation.
|
||||||
|
func (m *UserMutation) AddedBalanceNotifyThreshold() (r float64, exists bool) {
|
||||||
|
v := m.addbalance_notify_threshold
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (m *UserMutation) ClearBalanceNotifyThreshold() {
|
||||||
|
m.balance_notify_threshold = nil
|
||||||
|
m.addbalance_notify_threshold = nil
|
||||||
|
m.clearedFields[user.FieldBalanceNotifyThreshold] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdCleared returns if the "balance_notify_threshold" field was cleared in this mutation.
|
||||||
|
func (m *UserMutation) BalanceNotifyThresholdCleared() bool {
|
||||||
|
_, ok := m.clearedFields[user.FieldBalanceNotifyThreshold]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetBalanceNotifyThreshold resets all changes to the "balance_notify_threshold" field.
|
||||||
|
func (m *UserMutation) ResetBalanceNotifyThreshold() {
|
||||||
|
m.balance_notify_threshold = nil
|
||||||
|
m.addbalance_notify_threshold = nil
|
||||||
|
delete(m.clearedFields, user.FieldBalanceNotifyThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (m *UserMutation) SetBalanceNotifyExtraEmails(s string) {
|
||||||
|
m.balance_notify_extra_emails = &s
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmails returns the value of the "balance_notify_extra_emails" field in the mutation.
|
||||||
|
func (m *UserMutation) BalanceNotifyExtraEmails() (r string, exists bool) {
|
||||||
|
v := m.balance_notify_extra_emails
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldBalanceNotifyExtraEmails returns the old "balance_notify_extra_emails" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldBalanceNotifyExtraEmails(ctx context.Context) (v string, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldBalanceNotifyExtraEmails is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldBalanceNotifyExtraEmails requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldBalanceNotifyExtraEmails: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.BalanceNotifyExtraEmails, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetBalanceNotifyExtraEmails resets all changes to the "balance_notify_extra_emails" field.
|
||||||
|
func (m *UserMutation) ResetBalanceNotifyExtraEmails() {
|
||||||
|
m.balance_notify_extra_emails = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (m *UserMutation) SetTotalRecharged(f float64) {
|
||||||
|
m.total_recharged = &f
|
||||||
|
m.addtotal_recharged = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRecharged returns the value of the "total_recharged" field in the mutation.
|
||||||
|
func (m *UserMutation) TotalRecharged() (r float64, exists bool) {
|
||||||
|
v := m.total_recharged
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// OldTotalRecharged returns the old "total_recharged" field's value of the User entity.
|
||||||
|
// If the User object wasn't provided to the builder, the object is fetched from the database.
|
||||||
|
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||||
|
func (m *UserMutation) OldTotalRecharged(ctx context.Context) (v float64, err error) {
|
||||||
|
if !m.op.Is(OpUpdateOne) {
|
||||||
|
return v, errors.New("OldTotalRecharged is only allowed on UpdateOne operations")
|
||||||
|
}
|
||||||
|
if m.id == nil || m.oldValue == nil {
|
||||||
|
return v, errors.New("OldTotalRecharged requires an ID field in the mutation")
|
||||||
|
}
|
||||||
|
oldValue, err := m.oldValue(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return v, fmt.Errorf("querying old value for OldTotalRecharged: %w", err)
|
||||||
|
}
|
||||||
|
return oldValue.TotalRecharged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds f to the "total_recharged" field.
|
||||||
|
func (m *UserMutation) AddTotalRecharged(f float64) {
|
||||||
|
if m.addtotal_recharged != nil {
|
||||||
|
*m.addtotal_recharged += f
|
||||||
|
} else {
|
||||||
|
m.addtotal_recharged = &f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedTotalRecharged returns the value that was added to the "total_recharged" field in this mutation.
|
||||||
|
func (m *UserMutation) AddedTotalRecharged() (r float64, exists bool) {
|
||||||
|
v := m.addtotal_recharged
|
||||||
|
if v == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return *v, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTotalRecharged resets all changes to the "total_recharged" field.
|
||||||
|
func (m *UserMutation) ResetTotalRecharged() {
|
||||||
|
m.total_recharged = nil
|
||||||
|
m.addtotal_recharged = nil
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
|
||||||
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
func (m *UserMutation) AddAPIKeyIDs(ids ...int64) {
|
||||||
if m.api_keys == nil {
|
if m.api_keys == nil {
|
||||||
@ -29501,7 +29796,7 @@ func (m *UserMutation) Type() string {
|
|||||||
// order to get all numeric fields that were incremented/decremented, call
|
// order to get all numeric fields that were incremented/decremented, call
|
||||||
// AddedFields().
|
// AddedFields().
|
||||||
func (m *UserMutation) Fields() []string {
|
func (m *UserMutation) Fields() []string {
|
||||||
fields := make([]string, 0, 14)
|
fields := make([]string, 0, 19)
|
||||||
if m.created_at != nil {
|
if m.created_at != nil {
|
||||||
fields = append(fields, user.FieldCreatedAt)
|
fields = append(fields, user.FieldCreatedAt)
|
||||||
}
|
}
|
||||||
@ -29544,6 +29839,21 @@ func (m *UserMutation) Fields() []string {
|
|||||||
if m.totp_enabled_at != nil {
|
if m.totp_enabled_at != nil {
|
||||||
fields = append(fields, user.FieldTotpEnabledAt)
|
fields = append(fields, user.FieldTotpEnabledAt)
|
||||||
}
|
}
|
||||||
|
if m.balance_notify_enabled != nil {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyEnabled)
|
||||||
|
}
|
||||||
|
if m.balance_notify_threshold_type != nil {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyThresholdType)
|
||||||
|
}
|
||||||
|
if m.balance_notify_threshold != nil {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyThreshold)
|
||||||
|
}
|
||||||
|
if m.balance_notify_extra_emails != nil {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyExtraEmails)
|
||||||
|
}
|
||||||
|
if m.total_recharged != nil {
|
||||||
|
fields = append(fields, user.FieldTotalRecharged)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29580,6 +29890,16 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) {
|
|||||||
return m.TotpEnabled()
|
return m.TotpEnabled()
|
||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
return m.TotpEnabledAt()
|
return m.TotpEnabledAt()
|
||||||
|
case user.FieldBalanceNotifyEnabled:
|
||||||
|
return m.BalanceNotifyEnabled()
|
||||||
|
case user.FieldBalanceNotifyThresholdType:
|
||||||
|
return m.BalanceNotifyThresholdType()
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
return m.BalanceNotifyThreshold()
|
||||||
|
case user.FieldBalanceNotifyExtraEmails:
|
||||||
|
return m.BalanceNotifyExtraEmails()
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
return m.TotalRecharged()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -29617,6 +29937,16 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er
|
|||||||
return m.OldTotpEnabled(ctx)
|
return m.OldTotpEnabled(ctx)
|
||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
return m.OldTotpEnabledAt(ctx)
|
return m.OldTotpEnabledAt(ctx)
|
||||||
|
case user.FieldBalanceNotifyEnabled:
|
||||||
|
return m.OldBalanceNotifyEnabled(ctx)
|
||||||
|
case user.FieldBalanceNotifyThresholdType:
|
||||||
|
return m.OldBalanceNotifyThresholdType(ctx)
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
return m.OldBalanceNotifyThreshold(ctx)
|
||||||
|
case user.FieldBalanceNotifyExtraEmails:
|
||||||
|
return m.OldBalanceNotifyExtraEmails(ctx)
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
return m.OldTotalRecharged(ctx)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown User field %s", name)
|
return nil, fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@ -29724,6 +30054,41 @@ func (m *UserMutation) SetField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.SetTotpEnabledAt(v)
|
m.SetTotpEnabledAt(v)
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldBalanceNotifyEnabled:
|
||||||
|
v, ok := value.(bool)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetBalanceNotifyEnabled(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThresholdType:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetBalanceNotifyThresholdType(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetBalanceNotifyThreshold(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyExtraEmails:
|
||||||
|
v, ok := value.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetBalanceNotifyExtraEmails(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.SetTotalRecharged(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
@ -29738,6 +30103,12 @@ func (m *UserMutation) AddedFields() []string {
|
|||||||
if m.addconcurrency != nil {
|
if m.addconcurrency != nil {
|
||||||
fields = append(fields, user.FieldConcurrency)
|
fields = append(fields, user.FieldConcurrency)
|
||||||
}
|
}
|
||||||
|
if m.addbalance_notify_threshold != nil {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyThreshold)
|
||||||
|
}
|
||||||
|
if m.addtotal_recharged != nil {
|
||||||
|
fields = append(fields, user.FieldTotalRecharged)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29750,6 +30121,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) {
|
|||||||
return m.AddedBalance()
|
return m.AddedBalance()
|
||||||
case user.FieldConcurrency:
|
case user.FieldConcurrency:
|
||||||
return m.AddedConcurrency()
|
return m.AddedConcurrency()
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
return m.AddedBalanceNotifyThreshold()
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
return m.AddedTotalRecharged()
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
@ -29773,6 +30148,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error {
|
|||||||
}
|
}
|
||||||
m.AddConcurrency(v)
|
m.AddConcurrency(v)
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddBalanceNotifyThreshold(v)
|
||||||
|
return nil
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
v, ok := value.(float64)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||||
|
}
|
||||||
|
m.AddTotalRecharged(v)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User numeric field %s", name)
|
return fmt.Errorf("unknown User numeric field %s", name)
|
||||||
}
|
}
|
||||||
@ -29790,6 +30179,9 @@ func (m *UserMutation) ClearedFields() []string {
|
|||||||
if m.FieldCleared(user.FieldTotpEnabledAt) {
|
if m.FieldCleared(user.FieldTotpEnabledAt) {
|
||||||
fields = append(fields, user.FieldTotpEnabledAt)
|
fields = append(fields, user.FieldTotpEnabledAt)
|
||||||
}
|
}
|
||||||
|
if m.FieldCleared(user.FieldBalanceNotifyThreshold) {
|
||||||
|
fields = append(fields, user.FieldBalanceNotifyThreshold)
|
||||||
|
}
|
||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29813,6 +30205,9 @@ func (m *UserMutation) ClearField(name string) error {
|
|||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
m.ClearTotpEnabledAt()
|
m.ClearTotpEnabledAt()
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
m.ClearBalanceNotifyThreshold()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User nullable field %s", name)
|
return fmt.Errorf("unknown User nullable field %s", name)
|
||||||
}
|
}
|
||||||
@ -29863,6 +30258,21 @@ func (m *UserMutation) ResetField(name string) error {
|
|||||||
case user.FieldTotpEnabledAt:
|
case user.FieldTotpEnabledAt:
|
||||||
m.ResetTotpEnabledAt()
|
m.ResetTotpEnabledAt()
|
||||||
return nil
|
return nil
|
||||||
|
case user.FieldBalanceNotifyEnabled:
|
||||||
|
m.ResetBalanceNotifyEnabled()
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThresholdType:
|
||||||
|
m.ResetBalanceNotifyThresholdType()
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
m.ResetBalanceNotifyThreshold()
|
||||||
|
return nil
|
||||||
|
case user.FieldBalanceNotifyExtraEmails:
|
||||||
|
m.ResetBalanceNotifyExtraEmails()
|
||||||
|
return nil
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
m.ResetTotalRecharged()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unknown User field %s", name)
|
return fmt.Errorf("unknown User field %s", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,6 +35,8 @@ type PaymentProviderInstance struct {
|
|||||||
Limits string `json:"limits,omitempty"`
|
Limits string `json:"limits,omitempty"`
|
||||||
// RefundEnabled holds the value of the "refund_enabled" field.
|
// RefundEnabled holds the value of the "refund_enabled" field.
|
||||||
RefundEnabled bool `json:"refund_enabled,omitempty"`
|
RefundEnabled bool `json:"refund_enabled,omitempty"`
|
||||||
|
// AllowUserRefund holds the value of the "allow_user_refund" field.
|
||||||
|
AllowUserRefund bool `json:"allow_user_refund,omitempty"`
|
||||||
// CreatedAt holds the value of the "created_at" field.
|
// CreatedAt holds the value of the "created_at" field.
|
||||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||||
// UpdatedAt holds the value of the "updated_at" field.
|
// UpdatedAt holds the value of the "updated_at" field.
|
||||||
@ -47,7 +49,7 @@ func (*PaymentProviderInstance) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled:
|
case paymentproviderinstance.FieldEnabled, paymentproviderinstance.FieldRefundEnabled, paymentproviderinstance.FieldAllowUserRefund:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
|
case paymentproviderinstance.FieldID, paymentproviderinstance.FieldSortOrder:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
@ -130,6 +132,12 @@ func (_m *PaymentProviderInstance) assignValues(columns []string, values []any)
|
|||||||
} else if value.Valid {
|
} else if value.Valid {
|
||||||
_m.RefundEnabled = value.Bool
|
_m.RefundEnabled = value.Bool
|
||||||
}
|
}
|
||||||
|
case paymentproviderinstance.FieldAllowUserRefund:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field allow_user_refund", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.AllowUserRefund = value.Bool
|
||||||
|
}
|
||||||
case paymentproviderinstance.FieldCreatedAt:
|
case paymentproviderinstance.FieldCreatedAt:
|
||||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||||
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
return fmt.Errorf("unexpected type %T for field created_at", values[i])
|
||||||
@ -205,6 +213,9 @@ func (_m *PaymentProviderInstance) String() string {
|
|||||||
builder.WriteString("refund_enabled=")
|
builder.WriteString("refund_enabled=")
|
||||||
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
|
builder.WriteString(fmt.Sprintf("%v", _m.RefundEnabled))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("allow_user_refund=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.AllowUserRefund))
|
||||||
|
builder.WriteString(", ")
|
||||||
builder.WriteString("created_at=")
|
builder.WriteString("created_at=")
|
||||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
|
|||||||
@ -31,6 +31,8 @@ const (
|
|||||||
FieldLimits = "limits"
|
FieldLimits = "limits"
|
||||||
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
|
// FieldRefundEnabled holds the string denoting the refund_enabled field in the database.
|
||||||
FieldRefundEnabled = "refund_enabled"
|
FieldRefundEnabled = "refund_enabled"
|
||||||
|
// FieldAllowUserRefund holds the string denoting the allow_user_refund field in the database.
|
||||||
|
FieldAllowUserRefund = "allow_user_refund"
|
||||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||||
FieldCreatedAt = "created_at"
|
FieldCreatedAt = "created_at"
|
||||||
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
// FieldUpdatedAt holds the string denoting the updated_at field in the database.
|
||||||
@ -51,6 +53,7 @@ var Columns = []string{
|
|||||||
FieldSortOrder,
|
FieldSortOrder,
|
||||||
FieldLimits,
|
FieldLimits,
|
||||||
FieldRefundEnabled,
|
FieldRefundEnabled,
|
||||||
|
FieldAllowUserRefund,
|
||||||
FieldCreatedAt,
|
FieldCreatedAt,
|
||||||
FieldUpdatedAt,
|
FieldUpdatedAt,
|
||||||
}
|
}
|
||||||
@ -88,6 +91,8 @@ var (
|
|||||||
DefaultLimits string
|
DefaultLimits string
|
||||||
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
|
// DefaultRefundEnabled holds the default value on creation for the "refund_enabled" field.
|
||||||
DefaultRefundEnabled bool
|
DefaultRefundEnabled bool
|
||||||
|
// DefaultAllowUserRefund holds the default value on creation for the "allow_user_refund" field.
|
||||||
|
DefaultAllowUserRefund bool
|
||||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||||
DefaultCreatedAt func() time.Time
|
DefaultCreatedAt func() time.Time
|
||||||
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
// DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
|
||||||
@ -149,6 +154,11 @@ func ByRefundEnabled(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
|
return sql.OrderByField(FieldRefundEnabled, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByAllowUserRefund orders the results by the allow_user_refund field.
|
||||||
|
func ByAllowUserRefund(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldAllowUserRefund, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByCreatedAt orders the results by the created_at field.
|
// ByCreatedAt orders the results by the created_at field.
|
||||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||||
|
|||||||
@ -99,6 +99,11 @@ func RefundEnabled(v bool) predicate.PaymentProviderInstance {
|
|||||||
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
|
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldRefundEnabled, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowUserRefund applies equality check predicate on the "allow_user_refund" field. It's identical to AllowUserRefundEQ.
|
||||||
|
func AllowUserRefund(v bool) predicate.PaymentProviderInstance {
|
||||||
|
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||||
func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
|
func CreatedAt(v time.Time) predicate.PaymentProviderInstance {
|
||||||
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@ -559,6 +564,16 @@ func RefundEnabledNEQ(v bool) predicate.PaymentProviderInstance {
|
|||||||
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
|
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldRefundEnabled, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowUserRefundEQ applies the EQ predicate on the "allow_user_refund" field.
|
||||||
|
func AllowUserRefundEQ(v bool) predicate.PaymentProviderInstance {
|
||||||
|
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldAllowUserRefund, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowUserRefundNEQ applies the NEQ predicate on the "allow_user_refund" field.
|
||||||
|
func AllowUserRefundNEQ(v bool) predicate.PaymentProviderInstance {
|
||||||
|
return predicate.PaymentProviderInstance(sql.FieldNEQ(FieldAllowUserRefund, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
|
func CreatedAtEQ(v time.Time) predicate.PaymentProviderInstance {
|
||||||
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.PaymentProviderInstance(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
|
|||||||
@ -132,6 +132,20 @@ func (_c *PaymentProviderInstanceCreate) SetNillableRefundEnabled(v *bool) *Paym
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (_c *PaymentProviderInstanceCreate) SetAllowUserRefund(v bool) *PaymentProviderInstanceCreate {
|
||||||
|
_c.mutation.SetAllowUserRefund(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
|
||||||
|
func (_c *PaymentProviderInstanceCreate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetAllowUserRefund(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// SetCreatedAt sets the "created_at" field.
|
// SetCreatedAt sets the "created_at" field.
|
||||||
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
|
func (_c *PaymentProviderInstanceCreate) SetCreatedAt(v time.Time) *PaymentProviderInstanceCreate {
|
||||||
_c.mutation.SetCreatedAt(v)
|
_c.mutation.SetCreatedAt(v)
|
||||||
@ -223,6 +237,10 @@ func (_c *PaymentProviderInstanceCreate) defaults() {
|
|||||||
v := paymentproviderinstance.DefaultRefundEnabled
|
v := paymentproviderinstance.DefaultRefundEnabled
|
||||||
_c.mutation.SetRefundEnabled(v)
|
_c.mutation.SetRefundEnabled(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.AllowUserRefund(); !ok {
|
||||||
|
v := paymentproviderinstance.DefaultAllowUserRefund
|
||||||
|
_c.mutation.SetAllowUserRefund(v)
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
v := paymentproviderinstance.DefaultCreatedAt()
|
v := paymentproviderinstance.DefaultCreatedAt()
|
||||||
_c.mutation.SetCreatedAt(v)
|
_c.mutation.SetCreatedAt(v)
|
||||||
@ -282,6 +300,9 @@ func (_c *PaymentProviderInstanceCreate) check() error {
|
|||||||
if _, ok := _c.mutation.RefundEnabled(); !ok {
|
if _, ok := _c.mutation.RefundEnabled(); !ok {
|
||||||
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
|
return &ValidationError{Name: "refund_enabled", err: errors.New(`ent: missing required field "PaymentProviderInstance.refund_enabled"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.AllowUserRefund(); !ok {
|
||||||
|
return &ValidationError{Name: "allow_user_refund", err: errors.New(`ent: missing required field "PaymentProviderInstance.allow_user_refund"`)}
|
||||||
|
}
|
||||||
if _, ok := _c.mutation.CreatedAt(); !ok {
|
if _, ok := _c.mutation.CreatedAt(); !ok {
|
||||||
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
|
return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "PaymentProviderInstance.created_at"`)}
|
||||||
}
|
}
|
||||||
@ -351,6 +372,10 @@ func (_c *PaymentProviderInstanceCreate) createSpec() (*PaymentProviderInstance,
|
|||||||
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
||||||
_node.RefundEnabled = value
|
_node.RefundEnabled = value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.AllowUserRefund(); ok {
|
||||||
|
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
|
||||||
|
_node.AllowUserRefund = value
|
||||||
|
}
|
||||||
if value, ok := _c.mutation.CreatedAt(); ok {
|
if value, ok := _c.mutation.CreatedAt(); ok {
|
||||||
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
|
_spec.SetField(paymentproviderinstance.FieldCreatedAt, field.TypeTime, value)
|
||||||
_node.CreatedAt = value
|
_node.CreatedAt = value
|
||||||
@ -525,6 +550,18 @@ func (u *PaymentProviderInstanceUpsert) UpdateRefundEnabled() *PaymentProviderIn
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (u *PaymentProviderInstanceUpsert) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsert {
|
||||||
|
u.Set(paymentproviderinstance.FieldAllowUserRefund, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
|
||||||
|
func (u *PaymentProviderInstanceUpsert) UpdateAllowUserRefund() *PaymentProviderInstanceUpsert {
|
||||||
|
u.SetExcluded(paymentproviderinstance.FieldAllowUserRefund)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUpdatedAt sets the "updated_at" field.
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
|
func (u *PaymentProviderInstanceUpsert) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsert {
|
||||||
u.Set(paymentproviderinstance.FieldUpdatedAt, v)
|
u.Set(paymentproviderinstance.FieldUpdatedAt, v)
|
||||||
@ -715,6 +752,20 @@ func (u *PaymentProviderInstanceUpsertOne) UpdateRefundEnabled() *PaymentProvide
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (u *PaymentProviderInstanceUpsertOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertOne {
|
||||||
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
|
s.SetAllowUserRefund(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
|
||||||
|
func (u *PaymentProviderInstanceUpsertOne) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertOne {
|
||||||
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
|
s.UpdateAllowUserRefund()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetUpdatedAt sets the "updated_at" field.
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
|
func (u *PaymentProviderInstanceUpsertOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertOne {
|
||||||
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
@ -1073,6 +1124,20 @@ func (u *PaymentProviderInstanceUpsertBulk) UpdateRefundEnabled() *PaymentProvid
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (u *PaymentProviderInstanceUpsertBulk) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpsertBulk {
|
||||||
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
|
s.SetAllowUserRefund(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAllowUserRefund sets the "allow_user_refund" field to the value that was provided on create.
|
||||||
|
func (u *PaymentProviderInstanceUpsertBulk) UpdateAllowUserRefund() *PaymentProviderInstanceUpsertBulk {
|
||||||
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
|
s.UpdateAllowUserRefund()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetUpdatedAt sets the "updated_at" field.
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
|
func (u *PaymentProviderInstanceUpsertBulk) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpsertBulk {
|
||||||
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
return u.Update(func(s *PaymentProviderInstanceUpsert) {
|
||||||
|
|||||||
@ -161,6 +161,20 @@ func (_u *PaymentProviderInstanceUpdate) SetNillableRefundEnabled(v *bool) *Paym
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (_u *PaymentProviderInstanceUpdate) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdate {
|
||||||
|
_u.mutation.SetAllowUserRefund(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
|
||||||
|
func (_u *PaymentProviderInstanceUpdate) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAllowUserRefund(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUpdatedAt sets the "updated_at" field.
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
|
func (_u *PaymentProviderInstanceUpdate) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdate {
|
||||||
_u.mutation.SetUpdatedAt(v)
|
_u.mutation.SetUpdatedAt(v)
|
||||||
@ -275,6 +289,9 @@ func (_u *PaymentProviderInstanceUpdate) sqlSave(ctx context.Context) (_node int
|
|||||||
if value, ok := _u.mutation.RefundEnabled(); ok {
|
if value, ok := _u.mutation.RefundEnabled(); ok {
|
||||||
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.AllowUserRefund(); ok {
|
||||||
|
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
|
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
|
||||||
}
|
}
|
||||||
@ -431,6 +448,20 @@ func (_u *PaymentProviderInstanceUpdateOne) SetNillableRefundEnabled(v *bool) *P
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAllowUserRefund sets the "allow_user_refund" field.
|
||||||
|
func (_u *PaymentProviderInstanceUpdateOne) SetAllowUserRefund(v bool) *PaymentProviderInstanceUpdateOne {
|
||||||
|
_u.mutation.SetAllowUserRefund(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableAllowUserRefund sets the "allow_user_refund" field if the given value is not nil.
|
||||||
|
func (_u *PaymentProviderInstanceUpdateOne) SetNillableAllowUserRefund(v *bool) *PaymentProviderInstanceUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetAllowUserRefund(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// SetUpdatedAt sets the "updated_at" field.
|
// SetUpdatedAt sets the "updated_at" field.
|
||||||
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
|
func (_u *PaymentProviderInstanceUpdateOne) SetUpdatedAt(v time.Time) *PaymentProviderInstanceUpdateOne {
|
||||||
_u.mutation.SetUpdatedAt(v)
|
_u.mutation.SetUpdatedAt(v)
|
||||||
@ -575,6 +606,9 @@ func (_u *PaymentProviderInstanceUpdateOne) sqlSave(ctx context.Context) (_node
|
|||||||
if value, ok := _u.mutation.RefundEnabled(); ok {
|
if value, ok := _u.mutation.RefundEnabled(); ok {
|
||||||
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
_spec.SetField(paymentproviderinstance.FieldRefundEnabled, field.TypeBool, value)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.AllowUserRefund(); ok {
|
||||||
|
_spec.SetField(paymentproviderinstance.FieldAllowUserRefund, field.TypeBool, value)
|
||||||
|
}
|
||||||
if value, ok := _u.mutation.UpdatedAt(); ok {
|
if value, ok := _u.mutation.UpdatedAt(); ok {
|
||||||
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
|
_spec.SetField(paymentproviderinstance.FieldUpdatedAt, field.TypeTime, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -668,12 +668,16 @@ func init() {
|
|||||||
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
|
paymentproviderinstanceDescRefundEnabled := paymentproviderinstanceFields[8].Descriptor()
|
||||||
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
|
// paymentproviderinstance.DefaultRefundEnabled holds the default value on creation for the refund_enabled field.
|
||||||
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
|
paymentproviderinstance.DefaultRefundEnabled = paymentproviderinstanceDescRefundEnabled.Default.(bool)
|
||||||
|
// paymentproviderinstanceDescAllowUserRefund is the schema descriptor for allow_user_refund field.
|
||||||
|
paymentproviderinstanceDescAllowUserRefund := paymentproviderinstanceFields[9].Descriptor()
|
||||||
|
// paymentproviderinstance.DefaultAllowUserRefund holds the default value on creation for the allow_user_refund field.
|
||||||
|
paymentproviderinstance.DefaultAllowUserRefund = paymentproviderinstanceDescAllowUserRefund.Default.(bool)
|
||||||
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
|
// paymentproviderinstanceDescCreatedAt is the schema descriptor for created_at field.
|
||||||
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[9].Descriptor()
|
paymentproviderinstanceDescCreatedAt := paymentproviderinstanceFields[10].Descriptor()
|
||||||
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
|
// paymentproviderinstance.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||||
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
|
paymentproviderinstance.DefaultCreatedAt = paymentproviderinstanceDescCreatedAt.Default.(func() time.Time)
|
||||||
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
|
// paymentproviderinstanceDescUpdatedAt is the schema descriptor for updated_at field.
|
||||||
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[10].Descriptor()
|
paymentproviderinstanceDescUpdatedAt := paymentproviderinstanceFields[11].Descriptor()
|
||||||
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
// paymentproviderinstance.DefaultUpdatedAt holds the default value on creation for the updated_at field.
|
||||||
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
|
paymentproviderinstance.DefaultUpdatedAt = paymentproviderinstanceDescUpdatedAt.Default.(func() time.Time)
|
||||||
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
// paymentproviderinstance.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
|
||||||
@ -1293,6 +1297,22 @@ func init() {
|
|||||||
userDescTotpEnabled := userFields[9].Descriptor()
|
userDescTotpEnabled := userFields[9].Descriptor()
|
||||||
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
// user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field.
|
||||||
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool)
|
||||||
|
// userDescBalanceNotifyEnabled is the schema descriptor for balance_notify_enabled field.
|
||||||
|
userDescBalanceNotifyEnabled := userFields[11].Descriptor()
|
||||||
|
// user.DefaultBalanceNotifyEnabled holds the default value on creation for the balance_notify_enabled field.
|
||||||
|
user.DefaultBalanceNotifyEnabled = userDescBalanceNotifyEnabled.Default.(bool)
|
||||||
|
// userDescBalanceNotifyThresholdType is the schema descriptor for balance_notify_threshold_type field.
|
||||||
|
userDescBalanceNotifyThresholdType := userFields[12].Descriptor()
|
||||||
|
// user.DefaultBalanceNotifyThresholdType holds the default value on creation for the balance_notify_threshold_type field.
|
||||||
|
user.DefaultBalanceNotifyThresholdType = userDescBalanceNotifyThresholdType.Default.(string)
|
||||||
|
// userDescBalanceNotifyExtraEmails is the schema descriptor for balance_notify_extra_emails field.
|
||||||
|
userDescBalanceNotifyExtraEmails := userFields[14].Descriptor()
|
||||||
|
// user.DefaultBalanceNotifyExtraEmails holds the default value on creation for the balance_notify_extra_emails field.
|
||||||
|
user.DefaultBalanceNotifyExtraEmails = userDescBalanceNotifyExtraEmails.Default.(string)
|
||||||
|
// userDescTotalRecharged is the schema descriptor for total_recharged field.
|
||||||
|
userDescTotalRecharged := userFields[15].Descriptor()
|
||||||
|
// user.DefaultTotalRecharged holds the default value on creation for the total_recharged field.
|
||||||
|
user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64)
|
||||||
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
userallowedgroupFields := schema.UserAllowedGroup{}.Fields()
|
||||||
_ = userallowedgroupFields
|
_ = userallowedgroupFields
|
||||||
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
// userallowedgroupDescCreatedAt is the schema descriptor for created_at field.
|
||||||
|
|||||||
@ -53,6 +53,8 @@ func (PaymentProviderInstance) Fields() []ent.Field {
|
|||||||
Default(""),
|
Default(""),
|
||||||
field.Bool("refund_enabled").
|
field.Bool("refund_enabled").
|
||||||
Default(false),
|
Default(false),
|
||||||
|
field.Bool("allow_user_refund").
|
||||||
|
Default(false),
|
||||||
field.Time("created_at").
|
field.Time("created_at").
|
||||||
Immutable().
|
Immutable().
|
||||||
Default(time.Now).
|
Default(time.Now).
|
||||||
|
|||||||
@ -72,6 +72,22 @@ func (User) Fields() []ent.Field {
|
|||||||
field.Time("totp_enabled_at").
|
field.Time("totp_enabled_at").
|
||||||
Optional().
|
Optional().
|
||||||
Nillable(),
|
Nillable(),
|
||||||
|
|
||||||
|
// 余额不足通知
|
||||||
|
field.Bool("balance_notify_enabled").
|
||||||
|
Default(true),
|
||||||
|
field.String("balance_notify_threshold_type").
|
||||||
|
Default("fixed"), // "fixed" | "percentage"
|
||||||
|
field.Float("balance_notify_threshold").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||||
|
Optional().
|
||||||
|
Nillable(),
|
||||||
|
field.String("balance_notify_extra_emails").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "text"}).
|
||||||
|
Default("[]"),
|
||||||
|
field.Float("total_recharged").
|
||||||
|
SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}).
|
||||||
|
Default(0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -45,6 +45,16 @@ type User struct {
|
|||||||
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
TotpEnabled bool `json:"totp_enabled,omitempty"`
|
||||||
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
// TotpEnabledAt holds the value of the "totp_enabled_at" field.
|
||||||
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"`
|
||||||
|
// BalanceNotifyEnabled holds the value of the "balance_notify_enabled" field.
|
||||||
|
BalanceNotifyEnabled bool `json:"balance_notify_enabled,omitempty"`
|
||||||
|
// BalanceNotifyThresholdType holds the value of the "balance_notify_threshold_type" field.
|
||||||
|
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type,omitempty"`
|
||||||
|
// BalanceNotifyThreshold holds the value of the "balance_notify_threshold" field.
|
||||||
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||||
|
// BalanceNotifyExtraEmails holds the value of the "balance_notify_extra_emails" field.
|
||||||
|
BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"`
|
||||||
|
// TotalRecharged holds the value of the "total_recharged" field.
|
||||||
|
TotalRecharged float64 `json:"total_recharged,omitempty"`
|
||||||
// Edges holds the relations/edges for other nodes in the graph.
|
// Edges holds the relations/edges for other nodes in the graph.
|
||||||
// The values are being populated by the UserQuery when eager-loading is set.
|
// The values are being populated by the UserQuery when eager-loading is set.
|
||||||
Edges UserEdges `json:"edges"`
|
Edges UserEdges `json:"edges"`
|
||||||
@ -184,13 +194,13 @@ func (*User) scanValues(columns []string) ([]any, error) {
|
|||||||
values := make([]any, len(columns))
|
values := make([]any, len(columns))
|
||||||
for i := range columns {
|
for i := range columns {
|
||||||
switch columns[i] {
|
switch columns[i] {
|
||||||
case user.FieldTotpEnabled:
|
case user.FieldTotpEnabled, user.FieldBalanceNotifyEnabled:
|
||||||
values[i] = new(sql.NullBool)
|
values[i] = new(sql.NullBool)
|
||||||
case user.FieldBalance:
|
case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged:
|
||||||
values[i] = new(sql.NullFloat64)
|
values[i] = new(sql.NullFloat64)
|
||||||
case user.FieldID, user.FieldConcurrency:
|
case user.FieldID, user.FieldConcurrency:
|
||||||
values[i] = new(sql.NullInt64)
|
values[i] = new(sql.NullInt64)
|
||||||
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted:
|
case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails:
|
||||||
values[i] = new(sql.NullString)
|
values[i] = new(sql.NullString)
|
||||||
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
|
case user.FieldCreatedAt, user.FieldUpdatedAt, user.FieldDeletedAt, user.FieldTotpEnabledAt:
|
||||||
values[i] = new(sql.NullTime)
|
values[i] = new(sql.NullTime)
|
||||||
@ -302,6 +312,37 @@ func (_m *User) assignValues(columns []string, values []any) error {
|
|||||||
_m.TotpEnabledAt = new(time.Time)
|
_m.TotpEnabledAt = new(time.Time)
|
||||||
*_m.TotpEnabledAt = value.Time
|
*_m.TotpEnabledAt = value.Time
|
||||||
}
|
}
|
||||||
|
case user.FieldBalanceNotifyEnabled:
|
||||||
|
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field balance_notify_enabled", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.BalanceNotifyEnabled = value.Bool
|
||||||
|
}
|
||||||
|
case user.FieldBalanceNotifyThresholdType:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field balance_notify_threshold_type", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.BalanceNotifyThresholdType = value.String
|
||||||
|
}
|
||||||
|
case user.FieldBalanceNotifyThreshold:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field balance_notify_threshold", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.BalanceNotifyThreshold = new(float64)
|
||||||
|
*_m.BalanceNotifyThreshold = value.Float64
|
||||||
|
}
|
||||||
|
case user.FieldBalanceNotifyExtraEmails:
|
||||||
|
if value, ok := values[i].(*sql.NullString); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field balance_notify_extra_emails", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.BalanceNotifyExtraEmails = value.String
|
||||||
|
}
|
||||||
|
case user.FieldTotalRecharged:
|
||||||
|
if value, ok := values[i].(*sql.NullFloat64); !ok {
|
||||||
|
return fmt.Errorf("unexpected type %T for field total_recharged", values[i])
|
||||||
|
} else if value.Valid {
|
||||||
|
_m.TotalRecharged = value.Float64
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
_m.selectValues.Set(columns[i], values[i])
|
_m.selectValues.Set(columns[i], values[i])
|
||||||
}
|
}
|
||||||
@ -440,6 +481,23 @@ func (_m *User) String() string {
|
|||||||
builder.WriteString("totp_enabled_at=")
|
builder.WriteString("totp_enabled_at=")
|
||||||
builder.WriteString(v.Format(time.ANSIC))
|
builder.WriteString(v.Format(time.ANSIC))
|
||||||
}
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("balance_notify_enabled=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.BalanceNotifyEnabled))
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("balance_notify_threshold_type=")
|
||||||
|
builder.WriteString(_m.BalanceNotifyThresholdType)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
if v := _m.BalanceNotifyThreshold; v != nil {
|
||||||
|
builder.WriteString("balance_notify_threshold=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||||
|
}
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("balance_notify_extra_emails=")
|
||||||
|
builder.WriteString(_m.BalanceNotifyExtraEmails)
|
||||||
|
builder.WriteString(", ")
|
||||||
|
builder.WriteString("total_recharged=")
|
||||||
|
builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged))
|
||||||
builder.WriteByte(')')
|
builder.WriteByte(')')
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,6 +43,16 @@ const (
|
|||||||
FieldTotpEnabled = "totp_enabled"
|
FieldTotpEnabled = "totp_enabled"
|
||||||
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
// FieldTotpEnabledAt holds the string denoting the totp_enabled_at field in the database.
|
||||||
FieldTotpEnabledAt = "totp_enabled_at"
|
FieldTotpEnabledAt = "totp_enabled_at"
|
||||||
|
// FieldBalanceNotifyEnabled holds the string denoting the balance_notify_enabled field in the database.
|
||||||
|
FieldBalanceNotifyEnabled = "balance_notify_enabled"
|
||||||
|
// FieldBalanceNotifyThresholdType holds the string denoting the balance_notify_threshold_type field in the database.
|
||||||
|
FieldBalanceNotifyThresholdType = "balance_notify_threshold_type"
|
||||||
|
// FieldBalanceNotifyThreshold holds the string denoting the balance_notify_threshold field in the database.
|
||||||
|
FieldBalanceNotifyThreshold = "balance_notify_threshold"
|
||||||
|
// FieldBalanceNotifyExtraEmails holds the string denoting the balance_notify_extra_emails field in the database.
|
||||||
|
FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails"
|
||||||
|
// FieldTotalRecharged holds the string denoting the total_recharged field in the database.
|
||||||
|
FieldTotalRecharged = "total_recharged"
|
||||||
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
|
||||||
EdgeAPIKeys = "api_keys"
|
EdgeAPIKeys = "api_keys"
|
||||||
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
|
||||||
@ -161,6 +171,11 @@ var Columns = []string{
|
|||||||
FieldTotpSecretEncrypted,
|
FieldTotpSecretEncrypted,
|
||||||
FieldTotpEnabled,
|
FieldTotpEnabled,
|
||||||
FieldTotpEnabledAt,
|
FieldTotpEnabledAt,
|
||||||
|
FieldBalanceNotifyEnabled,
|
||||||
|
FieldBalanceNotifyThresholdType,
|
||||||
|
FieldBalanceNotifyThreshold,
|
||||||
|
FieldBalanceNotifyExtraEmails,
|
||||||
|
FieldTotalRecharged,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -217,6 +232,14 @@ var (
|
|||||||
DefaultNotes string
|
DefaultNotes string
|
||||||
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
// DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field.
|
||||||
DefaultTotpEnabled bool
|
DefaultTotpEnabled bool
|
||||||
|
// DefaultBalanceNotifyEnabled holds the default value on creation for the "balance_notify_enabled" field.
|
||||||
|
DefaultBalanceNotifyEnabled bool
|
||||||
|
// DefaultBalanceNotifyThresholdType holds the default value on creation for the "balance_notify_threshold_type" field.
|
||||||
|
DefaultBalanceNotifyThresholdType string
|
||||||
|
// DefaultBalanceNotifyExtraEmails holds the default value on creation for the "balance_notify_extra_emails" field.
|
||||||
|
DefaultBalanceNotifyExtraEmails string
|
||||||
|
// DefaultTotalRecharged holds the default value on creation for the "total_recharged" field.
|
||||||
|
DefaultTotalRecharged float64
|
||||||
)
|
)
|
||||||
|
|
||||||
// OrderOption defines the ordering options for the User queries.
|
// OrderOption defines the ordering options for the User queries.
|
||||||
@ -297,6 +320,31 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption {
|
|||||||
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ByBalanceNotifyEnabled orders the results by the balance_notify_enabled field.
|
||||||
|
func ByBalanceNotifyEnabled(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldBalanceNotifyEnabled, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByBalanceNotifyThresholdType orders the results by the balance_notify_threshold_type field.
|
||||||
|
func ByBalanceNotifyThresholdType(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldBalanceNotifyThresholdType, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByBalanceNotifyThreshold orders the results by the balance_notify_threshold field.
|
||||||
|
func ByBalanceNotifyThreshold(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldBalanceNotifyThreshold, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByBalanceNotifyExtraEmails orders the results by the balance_notify_extra_emails field.
|
||||||
|
func ByBalanceNotifyExtraEmails(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldBalanceNotifyExtraEmails, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ByTotalRecharged orders the results by the total_recharged field.
|
||||||
|
func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption {
|
||||||
|
return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc()
|
||||||
|
}
|
||||||
|
|
||||||
// ByAPIKeysCount orders the results by api_keys count.
|
// ByAPIKeysCount orders the results by api_keys count.
|
||||||
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
|
||||||
return func(s *sql.Selector) {
|
return func(s *sql.Selector) {
|
||||||
|
|||||||
@ -125,6 +125,31 @@ func TotpEnabledAt(v time.Time) predicate.User {
|
|||||||
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyEnabled applies equality check predicate on the "balance_notify_enabled" field. It's identical to BalanceNotifyEnabledEQ.
|
||||||
|
func BalanceNotifyEnabled(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdType applies equality check predicate on the "balance_notify_threshold_type" field. It's identical to BalanceNotifyThresholdTypeEQ.
|
||||||
|
func BalanceNotifyThresholdType(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThreshold applies equality check predicate on the "balance_notify_threshold" field. It's identical to BalanceNotifyThresholdEQ.
|
||||||
|
func BalanceNotifyThreshold(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmails applies equality check predicate on the "balance_notify_extra_emails" field. It's identical to BalanceNotifyExtraEmailsEQ.
|
||||||
|
func BalanceNotifyExtraEmails(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRecharged applies equality check predicate on the "total_recharged" field. It's identical to TotalRechargedEQ.
|
||||||
|
func TotalRecharged(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||||
func CreatedAtEQ(v time.Time) predicate.User {
|
func CreatedAtEQ(v time.Time) predicate.User {
|
||||||
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
return predicate.User(sql.FieldEQ(FieldCreatedAt, v))
|
||||||
@ -860,6 +885,236 @@ func TotpEnabledAtNotNil() predicate.User {
|
|||||||
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyEnabledEQ applies the EQ predicate on the "balance_notify_enabled" field.
|
||||||
|
func BalanceNotifyEnabledEQ(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyEnabledNEQ applies the NEQ predicate on the "balance_notify_enabled" field.
|
||||||
|
func BalanceNotifyEnabledNEQ(v bool) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyEnabled, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeEQ applies the EQ predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeNEQ applies the NEQ predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeNEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeIn applies the In predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldBalanceNotifyThresholdType, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeNotIn applies the NotIn predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeNotIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThresholdType, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeGT applies the GT predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeGT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeGTE applies the GTE predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeGTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeLT applies the LT predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeLT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeLTE applies the LTE predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeLTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeContains applies the Contains predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeContains(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContains(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeHasPrefix applies the HasPrefix predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeHasPrefix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeHasSuffix applies the HasSuffix predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeHasSuffix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeEqualFold applies the EqualFold predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeEqualFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdTypeContainsFold applies the ContainsFold predicate on the "balance_notify_threshold_type" field.
|
||||||
|
func BalanceNotifyThresholdTypeContainsFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyThresholdType, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdEQ applies the EQ predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdEQ(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdNEQ applies the NEQ predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdNEQ(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdIn applies the In predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdIn(vs ...float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldBalanceNotifyThreshold, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdNotIn applies the NotIn predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdNotIn(vs ...float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyThreshold, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdGT applies the GT predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdGT(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdGTE applies the GTE predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdGTE(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdLT applies the LT predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdLT(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdLTE applies the LTE predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdLTE(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldBalanceNotifyThreshold, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdIsNil applies the IsNil predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdIsNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldIsNull(FieldBalanceNotifyThreshold))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyThresholdNotNil applies the NotNil predicate on the "balance_notify_threshold" field.
|
||||||
|
func BalanceNotifyThresholdNotNil() predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotNull(FieldBalanceNotifyThreshold))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsEQ applies the EQ predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsNEQ applies the NEQ predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsNEQ(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsIn applies the In predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldBalanceNotifyExtraEmails, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsNotIn applies the NotIn predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsNotIn(vs ...string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldBalanceNotifyExtraEmails, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsGT applies the GT predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsGT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsGTE applies the GTE predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsGTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsLT applies the LT predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsLT(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsLTE applies the LTE predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsLTE(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsContains applies the Contains predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsContains(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContains(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsHasPrefix applies the HasPrefix predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsHasPrefix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasPrefix(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsHasSuffix applies the HasSuffix predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsHasSuffix(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldHasSuffix(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsEqualFold applies the EqualFold predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsEqualFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEqualFold(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyExtraEmailsContainsFold applies the ContainsFold predicate on the "balance_notify_extra_emails" field.
|
||||||
|
func BalanceNotifyExtraEmailsContainsFold(v string) predicate.User {
|
||||||
|
return predicate.User(sql.FieldContainsFold(FieldBalanceNotifyExtraEmails, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedEQ applies the EQ predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedEQ(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldEQ(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedNEQ applies the NEQ predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedNEQ(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNEQ(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedIn applies the In predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedIn(vs ...float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldIn(FieldTotalRecharged, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedNotIn applies the NotIn predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedNotIn(vs ...float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldNotIn(FieldTotalRecharged, vs...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedGT applies the GT predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedGT(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGT(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedGTE applies the GTE predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedGTE(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldGTE(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedLT applies the LT predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedLT(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLT(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TotalRechargedLTE applies the LTE predicate on the "total_recharged" field.
|
||||||
|
func TotalRechargedLTE(v float64) predicate.User {
|
||||||
|
return predicate.User(sql.FieldLTE(FieldTotalRecharged, v))
|
||||||
|
}
|
||||||
|
|
||||||
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
|
||||||
func HasAPIKeys() predicate.User {
|
func HasAPIKeys() predicate.User {
|
||||||
return predicate.User(func(s *sql.Selector) {
|
return predicate.User(func(s *sql.Selector) {
|
||||||
|
|||||||
@ -211,6 +211,76 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate {
|
|||||||
return _c
|
return _c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (_c *UserCreate) SetBalanceNotifyEnabled(v bool) *UserCreate {
|
||||||
|
_c.mutation.SetBalanceNotifyEnabled(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableBalanceNotifyEnabled(v *bool) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetBalanceNotifyEnabled(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (_c *UserCreate) SetBalanceNotifyThresholdType(v string) *UserCreate {
|
||||||
|
_c.mutation.SetBalanceNotifyThresholdType(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableBalanceNotifyThresholdType(v *string) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetBalanceNotifyThresholdType(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (_c *UserCreate) SetBalanceNotifyThreshold(v float64) *UserCreate {
|
||||||
|
_c.mutation.SetBalanceNotifyThreshold(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableBalanceNotifyThreshold(v *float64) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetBalanceNotifyThreshold(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (_c *UserCreate) SetBalanceNotifyExtraEmails(v string) *UserCreate {
|
||||||
|
_c.mutation.SetBalanceNotifyExtraEmails(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableBalanceNotifyExtraEmails(v *string) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetBalanceNotifyExtraEmails(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (_c *UserCreate) SetTotalRecharged(v float64) *UserCreate {
|
||||||
|
_c.mutation.SetTotalRecharged(v)
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
|
||||||
|
func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate {
|
||||||
|
if v != nil {
|
||||||
|
_c.SetTotalRecharged(*v)
|
||||||
|
}
|
||||||
|
return _c
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate {
|
||||||
_c.mutation.AddAPIKeyIDs(ids...)
|
_c.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -440,6 +510,22 @@ func (_c *UserCreate) defaults() error {
|
|||||||
v := user.DefaultTotpEnabled
|
v := user.DefaultTotpEnabled
|
||||||
_c.mutation.SetTotpEnabled(v)
|
_c.mutation.SetTotpEnabled(v)
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
|
||||||
|
v := user.DefaultBalanceNotifyEnabled
|
||||||
|
_c.mutation.SetBalanceNotifyEnabled(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
|
||||||
|
v := user.DefaultBalanceNotifyThresholdType
|
||||||
|
_c.mutation.SetBalanceNotifyThresholdType(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
|
||||||
|
v := user.DefaultBalanceNotifyExtraEmails
|
||||||
|
_c.mutation.SetBalanceNotifyExtraEmails(v)
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.TotalRecharged(); !ok {
|
||||||
|
v := user.DefaultTotalRecharged
|
||||||
|
_c.mutation.SetTotalRecharged(v)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -503,6 +589,18 @@ func (_c *UserCreate) check() error {
|
|||||||
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
if _, ok := _c.mutation.TotpEnabled(); !ok {
|
||||||
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)}
|
||||||
}
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyEnabled(); !ok {
|
||||||
|
return &ValidationError{Name: "balance_notify_enabled", err: errors.New(`ent: missing required field "User.balance_notify_enabled"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyThresholdType(); !ok {
|
||||||
|
return &ValidationError{Name: "balance_notify_threshold_type", err: errors.New(`ent: missing required field "User.balance_notify_threshold_type"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.BalanceNotifyExtraEmails(); !ok {
|
||||||
|
return &ValidationError{Name: "balance_notify_extra_emails", err: errors.New(`ent: missing required field "User.balance_notify_extra_emails"`)}
|
||||||
|
}
|
||||||
|
if _, ok := _c.mutation.TotalRecharged(); !ok {
|
||||||
|
return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -586,6 +684,26 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) {
|
|||||||
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
_spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value)
|
||||||
_node.TotpEnabledAt = &value
|
_node.TotpEnabledAt = &value
|
||||||
}
|
}
|
||||||
|
if value, ok := _c.mutation.BalanceNotifyEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
|
||||||
|
_node.BalanceNotifyEnabled = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.BalanceNotifyThresholdType(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
|
||||||
|
_node.BalanceNotifyThresholdType = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.BalanceNotifyThreshold(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
|
||||||
|
_node.BalanceNotifyThreshold = &value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.BalanceNotifyExtraEmails(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
|
||||||
|
_node.BalanceNotifyExtraEmails = value
|
||||||
|
}
|
||||||
|
if value, ok := _c.mutation.TotalRecharged(); ok {
|
||||||
|
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
|
_node.TotalRecharged = value
|
||||||
|
}
|
||||||
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@ -988,6 +1106,84 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert {
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (u *UserUpsert) SetBalanceNotifyEnabled(v bool) *UserUpsert {
|
||||||
|
u.Set(user.FieldBalanceNotifyEnabled, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateBalanceNotifyEnabled() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldBalanceNotifyEnabled)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (u *UserUpsert) SetBalanceNotifyThresholdType(v string) *UserUpsert {
|
||||||
|
u.Set(user.FieldBalanceNotifyThresholdType, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateBalanceNotifyThresholdType() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldBalanceNotifyThresholdType)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsert) SetBalanceNotifyThreshold(v float64) *UserUpsert {
|
||||||
|
u.Set(user.FieldBalanceNotifyThreshold, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateBalanceNotifyThreshold() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldBalanceNotifyThreshold)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsert) AddBalanceNotifyThreshold(v float64) *UserUpsert {
|
||||||
|
u.Add(user.FieldBalanceNotifyThreshold, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsert) ClearBalanceNotifyThreshold() *UserUpsert {
|
||||||
|
u.SetNull(user.FieldBalanceNotifyThreshold)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (u *UserUpsert) SetBalanceNotifyExtraEmails(v string) *UserUpsert {
|
||||||
|
u.Set(user.FieldBalanceNotifyExtraEmails, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateBalanceNotifyExtraEmails() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldBalanceNotifyExtraEmails)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (u *UserUpsert) SetTotalRecharged(v float64) *UserUpsert {
|
||||||
|
u.Set(user.FieldTotalRecharged, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsert) UpdateTotalRecharged() *UserUpsert {
|
||||||
|
u.SetExcluded(user.FieldTotalRecharged)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds v to the "total_recharged" field.
|
||||||
|
func (u *UserUpsert) AddTotalRecharged(v float64) *UserUpsert {
|
||||||
|
u.Add(user.FieldTotalRecharged, v)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
// UpdateNewValues updates the mutable fields using the new values that were set on create.
|
||||||
// Using this option is equivalent to using:
|
// Using this option is equivalent to using:
|
||||||
//
|
//
|
||||||
@ -1250,6 +1446,97 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (u *UserUpsertOne) SetBalanceNotifyEnabled(v bool) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateBalanceNotifyEnabled() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (u *UserUpsertOne) SetBalanceNotifyThresholdType(v string) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyThresholdType(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateBalanceNotifyThresholdType() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyThresholdType()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertOne) SetBalanceNotifyThreshold(v float64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyThreshold(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertOne) AddBalanceNotifyThreshold(v float64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddBalanceNotifyThreshold(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateBalanceNotifyThreshold() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyThreshold()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertOne) ClearBalanceNotifyThreshold() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearBalanceNotifyThreshold()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (u *UserUpsertOne) SetBalanceNotifyExtraEmails(v string) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyExtraEmails(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateBalanceNotifyExtraEmails() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyExtraEmails()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (u *UserUpsertOne) SetTotalRecharged(v float64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotalRecharged(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds v to the "total_recharged" field.
|
||||||
|
func (u *UserUpsertOne) AddTotalRecharged(v float64) *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddTotalRecharged(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotalRecharged()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
func (u *UserUpsertOne) Exec(ctx context.Context) error {
|
||||||
if len(u.create.conflict) == 0 {
|
if len(u.create.conflict) == 0 {
|
||||||
@ -1678,6 +1965,97 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (u *UserUpsertBulk) SetBalanceNotifyEnabled(v bool) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyEnabled(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyEnabled sets the "balance_notify_enabled" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateBalanceNotifyEnabled() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyEnabled()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (u *UserUpsertBulk) SetBalanceNotifyThresholdType(v string) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyThresholdType(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateBalanceNotifyThresholdType() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyThresholdType()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertBulk) SetBalanceNotifyThreshold(v float64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyThreshold(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds v to the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertBulk) AddBalanceNotifyThreshold(v float64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddBalanceNotifyThreshold(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyThreshold sets the "balance_notify_threshold" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateBalanceNotifyThreshold() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyThreshold()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (u *UserUpsertBulk) ClearBalanceNotifyThreshold() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.ClearBalanceNotifyThreshold()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (u *UserUpsertBulk) SetBalanceNotifyExtraEmails(v string) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetBalanceNotifyExtraEmails(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateBalanceNotifyExtraEmails() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateBalanceNotifyExtraEmails()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (u *UserUpsertBulk) SetTotalRecharged(v float64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.SetTotalRecharged(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds v to the "total_recharged" field.
|
||||||
|
func (u *UserUpsertBulk) AddTotalRecharged(v float64) *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.AddTotalRecharged(v)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTotalRecharged sets the "total_recharged" field to the value that was provided on create.
|
||||||
|
func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk {
|
||||||
|
return u.Update(func(s *UserUpsert) {
|
||||||
|
s.UpdateTotalRecharged()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Exec executes the query.
|
// Exec executes the query.
|
||||||
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
func (u *UserUpsertBulk) Exec(ctx context.Context) error {
|
||||||
if u.create.err != nil {
|
if u.create.err != nil {
|
||||||
|
|||||||
@ -243,6 +243,96 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (_u *UserUpdate) SetBalanceNotifyEnabled(v bool) *UserUpdate {
|
||||||
|
_u.mutation.SetBalanceNotifyEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (_u *UserUpdate) SetBalanceNotifyThresholdType(v string) *UserUpdate {
|
||||||
|
_u.mutation.SetBalanceNotifyThresholdType(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyThresholdType(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdate) SetBalanceNotifyThreshold(v float64) *UserUpdate {
|
||||||
|
_u.mutation.ResetBalanceNotifyThreshold()
|
||||||
|
_u.mutation.SetBalanceNotifyThreshold(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyThreshold(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdate) AddBalanceNotifyThreshold(v float64) *UserUpdate {
|
||||||
|
_u.mutation.AddBalanceNotifyThreshold(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdate) ClearBalanceNotifyThreshold() *UserUpdate {
|
||||||
|
_u.mutation.ClearBalanceNotifyThreshold()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (_u *UserUpdate) SetBalanceNotifyExtraEmails(v string) *UserUpdate {
|
||||||
|
_u.mutation.SetBalanceNotifyExtraEmails(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyExtraEmails(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (_u *UserUpdate) SetTotalRecharged(v float64) *UserUpdate {
|
||||||
|
_u.mutation.ResetTotalRecharged()
|
||||||
|
_u.mutation.SetTotalRecharged(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdate) SetNillableTotalRecharged(v *float64) *UserUpdate {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotalRecharged(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds value to the "total_recharged" field.
|
||||||
|
func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate {
|
||||||
|
_u.mutation.AddTotalRecharged(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -746,6 +836,30 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
|
||||||
|
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.BalanceNotifyThresholdCleared() {
|
||||||
|
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotalRecharged(); ok {
|
||||||
|
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||||
|
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
@ -1434,6 +1548,96 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne {
|
|||||||
return _u
|
return _u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyEnabled sets the "balance_notify_enabled" field.
|
||||||
|
func (_u *UserUpdateOne) SetBalanceNotifyEnabled(v bool) *UserUpdateOne {
|
||||||
|
_u.mutation.SetBalanceNotifyEnabled(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyEnabled sets the "balance_notify_enabled" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableBalanceNotifyEnabled(v *bool) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyEnabled(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field.
|
||||||
|
func (_u *UserUpdateOne) SetBalanceNotifyThresholdType(v string) *UserUpdateOne {
|
||||||
|
_u.mutation.SetBalanceNotifyThresholdType(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThresholdType sets the "balance_notify_threshold_type" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableBalanceNotifyThresholdType(v *string) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyThresholdType(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyThreshold sets the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdateOne) SetBalanceNotifyThreshold(v float64) *UserUpdateOne {
|
||||||
|
_u.mutation.ResetBalanceNotifyThreshold()
|
||||||
|
_u.mutation.SetBalanceNotifyThreshold(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyThreshold sets the "balance_notify_threshold" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableBalanceNotifyThreshold(v *float64) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyThreshold(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBalanceNotifyThreshold adds value to the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdateOne) AddBalanceNotifyThreshold(v float64) *UserUpdateOne {
|
||||||
|
_u.mutation.AddBalanceNotifyThreshold(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBalanceNotifyThreshold clears the value of the "balance_notify_threshold" field.
|
||||||
|
func (_u *UserUpdateOne) ClearBalanceNotifyThreshold() *UserUpdateOne {
|
||||||
|
_u.mutation.ClearBalanceNotifyThreshold()
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field.
|
||||||
|
func (_u *UserUpdateOne) SetBalanceNotifyExtraEmails(v string) *UserUpdateOne {
|
||||||
|
_u.mutation.SetBalanceNotifyExtraEmails(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableBalanceNotifyExtraEmails sets the "balance_notify_extra_emails" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableBalanceNotifyExtraEmails(v *string) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetBalanceNotifyExtraEmails(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTotalRecharged sets the "total_recharged" field.
|
||||||
|
func (_u *UserUpdateOne) SetTotalRecharged(v float64) *UserUpdateOne {
|
||||||
|
_u.mutation.ResetTotalRecharged()
|
||||||
|
_u.mutation.SetTotalRecharged(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNillableTotalRecharged sets the "total_recharged" field if the given value is not nil.
|
||||||
|
func (_u *UserUpdateOne) SetNillableTotalRecharged(v *float64) *UserUpdateOne {
|
||||||
|
if v != nil {
|
||||||
|
_u.SetTotalRecharged(*v)
|
||||||
|
}
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTotalRecharged adds value to the "total_recharged" field.
|
||||||
|
func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne {
|
||||||
|
_u.mutation.AddTotalRecharged(v)
|
||||||
|
return _u
|
||||||
|
}
|
||||||
|
|
||||||
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
|
||||||
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne {
|
||||||
_u.mutation.AddAPIKeyIDs(ids...)
|
_u.mutation.AddAPIKeyIDs(ids...)
|
||||||
@ -1967,6 +2171,30 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) {
|
|||||||
if _u.mutation.TotpEnabledAtCleared() {
|
if _u.mutation.TotpEnabledAtCleared() {
|
||||||
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
_spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime)
|
||||||
}
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyEnabled(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyEnabled, field.TypeBool, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyThresholdType(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThresholdType, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyThreshold(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedBalanceNotifyThreshold(); ok {
|
||||||
|
_spec.AddField(user.FieldBalanceNotifyThreshold, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if _u.mutation.BalanceNotifyThresholdCleared() {
|
||||||
|
_spec.ClearField(user.FieldBalanceNotifyThreshold, field.TypeFloat64)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.BalanceNotifyExtraEmails(); ok {
|
||||||
|
_spec.SetField(user.FieldBalanceNotifyExtraEmails, field.TypeString, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.TotalRecharged(); ok {
|
||||||
|
_spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
|
if value, ok := _u.mutation.AddedTotalRecharged(); ok {
|
||||||
|
_spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value)
|
||||||
|
}
|
||||||
if _u.mutation.APIKeysCleared() {
|
if _u.mutation.APIKeysCleared() {
|
||||||
edge := &sqlgraph.EdgeSpec{
|
edge := &sqlgraph.EdgeSpec{
|
||||||
Rel: sqlgraph.O2M,
|
Rel: sqlgraph.O2M,
|
||||||
|
|||||||
@ -185,6 +185,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
|
|||||||
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
|
||||||
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
|
||||||
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@ -220,6 +222,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||||
|
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||||
@ -253,6 +257,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
|||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||||
|
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
@ -282,6 +288,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
|||||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
@ -314,6 +322,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
|||||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
|
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||||
|
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||||
|
|||||||
@ -28,7 +28,7 @@ const (
|
|||||||
|
|
||||||
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
|
// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
|
||||||
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
|
||||||
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
|
||||||
|
|
||||||
// UMQ(用户消息队列)模式常量
|
// UMQ(用户消息队列)模式常量
|
||||||
const (
|
const (
|
||||||
|
|||||||
@ -233,12 +233,13 @@ func TestLoadForcedCodexInstructionsTemplate(t *testing.T) {
|
|||||||
configPath := filepath.Join(tempDir, "config.yaml")
|
configPath := filepath.Join(tempDir, "config.yaml")
|
||||||
|
|
||||||
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
|
require.NoError(t, os.WriteFile(templatePath, []byte("server-prefix\n\n{{ .ExistingInstructions }}"), 0o644))
|
||||||
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+templatePath+"\"\n"), 0o644))
|
yamlSafePath := filepath.ToSlash(templatePath)
|
||||||
|
require.NoError(t, os.WriteFile(configPath, []byte("gateway:\n forced_codex_instructions_template_file: \""+yamlSafePath+"\"\n"), 0o644))
|
||||||
t.Setenv("DATA_DIR", tempDir)
|
t.Setenv("DATA_DIR", tempDir)
|
||||||
|
|
||||||
cfg, err := Load()
|
cfg, err := Load()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, templatePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
|
require.Equal(t, yamlSafePath, cfg.Gateway.ForcedCodexInstructionsTemplateFile)
|
||||||
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
|
require.Equal(t, "server-prefix\n\n{{ .ExistingInstructions }}", cfg.Gateway.ForcedCodexInstructionsTemplate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1412,6 +1412,12 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
|
|||||||
c.JSON(409, gin.H{
|
c.JSON(409, gin.H{
|
||||||
"error": "mixed_channel_warning",
|
"error": "mixed_channel_warning",
|
||||||
"message": mixedErr.Error(),
|
"message": mixedErr.Error(),
|
||||||
|
"details": gin.H{
|
||||||
|
"group_id": mixedErr.GroupID,
|
||||||
|
"group_name": mixedErr.GroupName,
|
||||||
|
"current_platform": mixedErr.CurrentPlatform,
|
||||||
|
"other_platform": mixedErr.OtherPlatform,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -26,24 +27,32 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s
|
|||||||
// --- Request / Response types ---
|
// --- Request / Response types ---
|
||||||
|
|
||||||
type createChannelRequest struct {
|
type createChannelRequest struct {
|
||||||
Name string `json:"name" binding:"required,max=100"`
|
Name string `json:"name" binding:"required,max=100"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
|
Features string `json:"features"`
|
||||||
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
|
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||||
|
AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type updateChannelRequest struct {
|
type updateChannelRequest struct {
|
||||||
Name string `json:"name" binding:"omitempty,max=100"`
|
Name string `json:"name" binding:"omitempty,max=100"`
|
||||||
Description *string `json:"description"`
|
Description *string `json:"description"`
|
||||||
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||||
GroupIDs *[]int64 `json:"group_ids"`
|
GroupIDs *[]int64 `json:"group_ids"`
|
||||||
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
||||||
RestrictModels *bool `json:"restrict_models"`
|
RestrictModels *bool `json:"restrict_models"`
|
||||||
|
Features *string `json:"features"`
|
||||||
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
|
ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"`
|
||||||
|
AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type channelModelPricingRequest struct {
|
type channelModelPricingRequest struct {
|
||||||
@ -71,18 +80,29 @@ type pricingIntervalRequest struct {
|
|||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountStatsPricingRuleRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
Pricing []channelModelPricingRequest `json:"pricing"`
|
||||||
|
}
|
||||||
|
|
||||||
type channelResponse struct {
|
type channelResponse struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
BillingModelSource string `json:"billing_model_source"`
|
BillingModelSource string `json:"billing_model_source"`
|
||||||
RestrictModels bool `json:"restrict_models"`
|
RestrictModels bool `json:"restrict_models"`
|
||||||
GroupIDs []int64 `json:"group_ids"`
|
Features string `json:"features"`
|
||||||
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
FeaturesConfig map[string]any `json:"features_config"`
|
||||||
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
CreatedAt string `json:"created_at"`
|
ModelPricing []channelModelPricingResponse `json:"model_pricing"`
|
||||||
UpdatedAt string `json:"updated_at"`
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
||||||
|
ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"`
|
||||||
|
AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type channelModelPricingResponse struct {
|
type channelModelPricingResponse struct {
|
||||||
@ -112,6 +132,14 @@ type pricingIntervalResponse struct {
|
|||||||
SortOrder int `json:"sort_order"`
|
SortOrder int `json:"sort_order"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type accountStatsPricingRuleResponse struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
GroupIDs []int64 `json:"group_ids"`
|
||||||
|
AccountIDs []int64 `json:"account_ids"`
|
||||||
|
Pricing []channelModelPricingResponse `json:"pricing"`
|
||||||
|
}
|
||||||
|
|
||||||
func channelToResponse(ch *service.Channel) *channelResponse {
|
func channelToResponse(ch *service.Channel) *channelResponse {
|
||||||
if ch == nil {
|
if ch == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -122,6 +150,8 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
|||||||
Description: ch.Description,
|
Description: ch.Description,
|
||||||
Status: ch.Status,
|
Status: ch.Status,
|
||||||
RestrictModels: ch.RestrictModels,
|
RestrictModels: ch.RestrictModels,
|
||||||
|
Features: ch.Features,
|
||||||
|
FeaturesConfig: ch.FeaturesConfig,
|
||||||
GroupIDs: ch.GroupIDs,
|
GroupIDs: ch.GroupIDs,
|
||||||
ModelMapping: ch.ModelMapping,
|
ModelMapping: ch.ModelMapping,
|
||||||
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
||||||
@ -142,6 +172,29 @@ func channelToResponse(ch *service.Channel) *channelResponse {
|
|||||||
for _, p := range ch.ModelPricing {
|
for _, p := range ch.ModelPricing {
|
||||||
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats
|
||||||
|
resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules))
|
||||||
|
for _, rule := range ch.AccountStatsPricingRules {
|
||||||
|
ruleResp := accountStatsPricingRuleResponse{
|
||||||
|
ID: rule.ID,
|
||||||
|
Name: rule.Name,
|
||||||
|
GroupIDs: rule.GroupIDs,
|
||||||
|
AccountIDs: rule.AccountIDs,
|
||||||
|
Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)),
|
||||||
|
}
|
||||||
|
if ruleResp.GroupIDs == nil {
|
||||||
|
ruleResp.GroupIDs = []int64{}
|
||||||
|
}
|
||||||
|
if ruleResp.AccountIDs == nil {
|
||||||
|
ruleResp.AccountIDs = []int64{}
|
||||||
|
}
|
||||||
|
for i := range rule.Pricing {
|
||||||
|
ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i]))
|
||||||
|
}
|
||||||
|
resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp)
|
||||||
|
}
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,9 +253,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
billingMode = service.BillingModeToken
|
billingMode = service.BillingModeToken
|
||||||
}
|
}
|
||||||
platform := r.Platform
|
platform := r.Platform
|
||||||
if platform == "" {
|
|
||||||
platform = service.PlatformAnthropic
|
|
||||||
}
|
|
||||||
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
||||||
for _, iv := range r.Intervals {
|
for _, iv := range r.Intervals {
|
||||||
intervals = append(intervals, service.PricingInterval{
|
intervals = append(intervals, service.PricingInterval{
|
||||||
@ -233,6 +283,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule {
|
||||||
|
return service.AccountStatsPricingRule{
|
||||||
|
Name: r.Name,
|
||||||
|
GroupIDs: r.GroupIDs,
|
||||||
|
AccountIDs: r.AccountIDs,
|
||||||
|
Pricing: pricingRequestToService(r.Pricing),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Handlers ---
|
// --- Handlers ---
|
||||||
|
|
||||||
// List handles listing channels with pagination
|
// List handles listing channels with pagination
|
||||||
@ -291,15 +350,42 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pricing := pricingRequestToService(req.ModelPricing)
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
|
// Main model_pricing requires a platform; default to anthropic for backward compatibility.
|
||||||
|
for i := range pricing {
|
||||||
|
if pricing[i].Platform == "" {
|
||||||
|
pricing[i].Platform = service.PlatformAnthropic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var statsRules []service.AccountStatsPricingRule
|
||||||
|
for i, r := range req.AccountStatsPricingRules {
|
||||||
|
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(r.Pricing) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
|
rule.SortOrder = i
|
||||||
|
statsRules = append(statsRules, rule)
|
||||||
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ModelPricing: pricing,
|
ModelPricing: pricing,
|
||||||
ModelMapping: req.ModelMapping,
|
ModelMapping: req.ModelMapping,
|
||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
|
Features: req.Features,
|
||||||
|
FeaturesConfig: req.FeaturesConfig,
|
||||||
|
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||||
|
AccountStatsPricingRules: statsRules,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -325,18 +411,45 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
input := &service.UpdateChannelInput{
|
input := &service.UpdateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
Status: req.Status,
|
Status: req.Status,
|
||||||
GroupIDs: req.GroupIDs,
|
GroupIDs: req.GroupIDs,
|
||||||
ModelMapping: req.ModelMapping,
|
ModelMapping: req.ModelMapping,
|
||||||
BillingModelSource: req.BillingModelSource,
|
BillingModelSource: req.BillingModelSource,
|
||||||
RestrictModels: req.RestrictModels,
|
RestrictModels: req.RestrictModels,
|
||||||
|
Features: req.Features,
|
||||||
|
FeaturesConfig: req.FeaturesConfig,
|
||||||
|
ApplyPricingToAccountStats: req.ApplyPricingToAccountStats,
|
||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
|
for i := range pricing {
|
||||||
|
if pricing[i].Platform == "" {
|
||||||
|
pricing[i].Platform = service.PlatformAnthropic
|
||||||
|
}
|
||||||
|
}
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
|
if req.AccountStatsPricingRules != nil {
|
||||||
|
statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules))
|
||||||
|
for i, r := range *req.AccountStatsPricingRules {
|
||||||
|
if len(r.GroupIDs) == 0 && len(r.AccountIDs) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_SCOPE",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(r.Pricing) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
|
rule.SortOrder = i
|
||||||
|
statsRules = append(statsRules, rule)
|
||||||
|
}
|
||||||
|
input.AccountStatsPricingRules = &statsRules
|
||||||
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
channel, err := h.channelService.Update(c.Request.Context(), id, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -273,13 +273,13 @@ func TestPricingRequestToService_Defaults(t *testing.T) {
|
|||||||
wantValue: string(service.BillingModeToken),
|
wantValue: string(service.BillingModeToken),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty platform defaults to anthropic",
|
name: "empty platform stays empty",
|
||||||
req: channelModelPricingRequest{
|
req: channelModelPricingRequest{
|
||||||
Models: []string{"m1"},
|
Models: []string{"m1"},
|
||||||
Platform: "",
|
Platform: "",
|
||||||
},
|
},
|
||||||
wantField: "Platform",
|
wantField: "Platform",
|
||||||
wantValue: "anthropic",
|
wantValue: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,11 +5,10 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||||
@ -175,6 +174,12 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
EnableFingerprintUnification: settings.EnableFingerprintUnification,
|
||||||
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
|
||||||
EnableCCHSigning: settings.EnableCCHSigning,
|
EnableCCHSigning: settings.EnableCCHSigning,
|
||||||
|
WebSearchEmulationEnabled: settings.WebSearchEmulationEnabled,
|
||||||
|
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
|
||||||
|
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
|
||||||
|
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
|
||||||
|
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
|
||||||
|
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
|
||||||
PaymentEnabled: paymentCfg.Enabled,
|
PaymentEnabled: paymentCfg.Enabled,
|
||||||
PaymentMinAmount: paymentCfg.MinAmount,
|
PaymentMinAmount: paymentCfg.MinAmount,
|
||||||
PaymentMaxAmount: paymentCfg.MaxAmount,
|
PaymentMaxAmount: paymentCfg.MaxAmount,
|
||||||
@ -183,6 +188,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
|||||||
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
|
PaymentMaxPendingOrders: paymentCfg.MaxPendingOrders,
|
||||||
PaymentEnabledTypes: paymentCfg.EnabledTypes,
|
PaymentEnabledTypes: paymentCfg.EnabledTypes,
|
||||||
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
|
PaymentBalanceDisabled: paymentCfg.BalanceDisabled,
|
||||||
|
PaymentBalanceRechargeMultiplier: paymentCfg.BalanceRechargeMultiplier,
|
||||||
|
PaymentRechargeFeeRate: paymentCfg.RechargeFeeRate,
|
||||||
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
|
PaymentLoadBalanceStrat: paymentCfg.LoadBalanceStrategy,
|
||||||
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
|
PaymentProductNamePrefix: paymentCfg.ProductNamePrefix,
|
||||||
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
|
PaymentProductNameSuffix: paymentCfg.ProductNameSuffix,
|
||||||
@ -304,20 +311,29 @@ type UpdateSettingsRequest struct {
|
|||||||
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
|
||||||
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
EnableCCHSigning *bool `json:"enable_cch_signing"`
|
||||||
|
|
||||||
|
// Balance low notification
|
||||||
|
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
|
||||||
|
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
|
||||||
|
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
|
||||||
|
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
|
||||||
|
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||||
|
|
||||||
// Payment configuration (integrated into settings, full replace)
|
// Payment configuration (integrated into settings, full replace)
|
||||||
PaymentEnabled *bool `json:"payment_enabled"`
|
PaymentEnabled *bool `json:"payment_enabled"`
|
||||||
PaymentMinAmount *float64 `json:"payment_min_amount"`
|
PaymentMinAmount *float64 `json:"payment_min_amount"`
|
||||||
PaymentMaxAmount *float64 `json:"payment_max_amount"`
|
PaymentMaxAmount *float64 `json:"payment_max_amount"`
|
||||||
PaymentDailyLimit *float64 `json:"payment_daily_limit"`
|
PaymentDailyLimit *float64 `json:"payment_daily_limit"`
|
||||||
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
|
PaymentOrderTimeoutMin *int `json:"payment_order_timeout_minutes"`
|
||||||
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
|
PaymentMaxPendingOrders *int `json:"payment_max_pending_orders"`
|
||||||
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
||||||
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
|
PaymentBalanceDisabled *bool `json:"payment_balance_disabled"`
|
||||||
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
|
PaymentBalanceRechargeMultiplier *float64 `json:"payment_balance_recharge_multiplier"`
|
||||||
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
|
PaymentRechargeFeeRate *float64 `json:"payment_recharge_fee_rate"`
|
||||||
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
|
PaymentLoadBalanceStrat *string `json:"payment_load_balance_strategy"`
|
||||||
PaymentHelpImageURL *string `json:"payment_help_image_url"`
|
PaymentProductNamePrefix *string `json:"payment_product_name_prefix"`
|
||||||
PaymentHelpText *string `json:"payment_help_text"`
|
PaymentProductNameSuffix *string `json:"payment_product_name_suffix"`
|
||||||
|
PaymentHelpImageURL *string `json:"payment_help_image_url"`
|
||||||
|
PaymentHelpText *string `json:"payment_help_text"`
|
||||||
|
|
||||||
// Cancel rate limit
|
// Cancel rate limit
|
||||||
PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
|
PaymentCancelRateLimitEnabled *bool `json:"payment_cancel_rate_limit_enabled"`
|
||||||
@ -881,6 +897,36 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
return previousSettings.EnableCCHSigning
|
return previousSettings.EnableCCHSigning
|
||||||
}(),
|
}(),
|
||||||
|
BalanceLowNotifyEnabled: func() bool {
|
||||||
|
if req.BalanceLowNotifyEnabled != nil {
|
||||||
|
return *req.BalanceLowNotifyEnabled
|
||||||
|
}
|
||||||
|
return previousSettings.BalanceLowNotifyEnabled
|
||||||
|
}(),
|
||||||
|
BalanceLowNotifyThreshold: func() float64 {
|
||||||
|
if req.BalanceLowNotifyThreshold != nil {
|
||||||
|
return *req.BalanceLowNotifyThreshold
|
||||||
|
}
|
||||||
|
return previousSettings.BalanceLowNotifyThreshold
|
||||||
|
}(),
|
||||||
|
BalanceLowNotifyRechargeURL: func() string {
|
||||||
|
if req.BalanceLowNotifyRechargeURL != nil {
|
||||||
|
return *req.BalanceLowNotifyRechargeURL
|
||||||
|
}
|
||||||
|
return previousSettings.BalanceLowNotifyRechargeURL
|
||||||
|
}(),
|
||||||
|
AccountQuotaNotifyEnabled: func() bool {
|
||||||
|
if req.AccountQuotaNotifyEnabled != nil {
|
||||||
|
return *req.AccountQuotaNotifyEnabled
|
||||||
|
}
|
||||||
|
return previousSettings.AccountQuotaNotifyEnabled
|
||||||
|
}(),
|
||||||
|
AccountQuotaNotifyEmails: func() []service.NotifyEmailEntry {
|
||||||
|
if req.AccountQuotaNotifyEmails != nil {
|
||||||
|
return dto.NotifyEmailEntriesToService(*req.AccountQuotaNotifyEmails)
|
||||||
|
}
|
||||||
|
return previousSettings.AccountQuotaNotifyEmails
|
||||||
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
|
||||||
@ -892,24 +938,26 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
// Skip if no payment fields were provided (prevents accidental wipe).
|
// Skip if no payment fields were provided (prevents accidental wipe).
|
||||||
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
if h.paymentConfigService != nil && hasPaymentFields(req) {
|
||||||
paymentReq := service.UpdatePaymentConfigRequest{
|
paymentReq := service.UpdatePaymentConfigRequest{
|
||||||
Enabled: req.PaymentEnabled,
|
Enabled: req.PaymentEnabled,
|
||||||
MinAmount: req.PaymentMinAmount,
|
MinAmount: req.PaymentMinAmount,
|
||||||
MaxAmount: req.PaymentMaxAmount,
|
MaxAmount: req.PaymentMaxAmount,
|
||||||
DailyLimit: req.PaymentDailyLimit,
|
DailyLimit: req.PaymentDailyLimit,
|
||||||
OrderTimeoutMin: req.PaymentOrderTimeoutMin,
|
OrderTimeoutMin: req.PaymentOrderTimeoutMin,
|
||||||
MaxPendingOrders: req.PaymentMaxPendingOrders,
|
MaxPendingOrders: req.PaymentMaxPendingOrders,
|
||||||
EnabledTypes: req.PaymentEnabledTypes,
|
EnabledTypes: req.PaymentEnabledTypes,
|
||||||
BalanceDisabled: req.PaymentBalanceDisabled,
|
BalanceDisabled: req.PaymentBalanceDisabled,
|
||||||
LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
|
BalanceRechargeMultiplier: req.PaymentBalanceRechargeMultiplier,
|
||||||
ProductNamePrefix: req.PaymentProductNamePrefix,
|
RechargeFeeRate: req.PaymentRechargeFeeRate,
|
||||||
ProductNameSuffix: req.PaymentProductNameSuffix,
|
LoadBalanceStrategy: req.PaymentLoadBalanceStrat,
|
||||||
HelpImageURL: req.PaymentHelpImageURL,
|
ProductNamePrefix: req.PaymentProductNamePrefix,
|
||||||
HelpText: req.PaymentHelpText,
|
ProductNameSuffix: req.PaymentProductNameSuffix,
|
||||||
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
|
HelpImageURL: req.PaymentHelpImageURL,
|
||||||
CancelRateLimitMax: req.PaymentCancelRateLimitMax,
|
HelpText: req.PaymentHelpText,
|
||||||
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
|
CancelRateLimitEnabled: req.PaymentCancelRateLimitEnabled,
|
||||||
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
|
CancelRateLimitMax: req.PaymentCancelRateLimitMax,
|
||||||
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
|
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
|
||||||
|
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
|
||||||
|
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
|
||||||
}
|
}
|
||||||
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
|
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
|
||||||
response.ErrorFrom(c, err)
|
response.ErrorFrom(c, err)
|
||||||
@ -1027,6 +1075,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
|
||||||
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
|
||||||
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
EnableCCHSigning: updatedSettings.EnableCCHSigning,
|
||||||
|
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
|
||||||
|
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
|
||||||
|
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
|
||||||
|
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
|
||||||
|
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
|
||||||
PaymentEnabled: updatedPaymentCfg.Enabled,
|
PaymentEnabled: updatedPaymentCfg.Enabled,
|
||||||
PaymentMinAmount: updatedPaymentCfg.MinAmount,
|
PaymentMinAmount: updatedPaymentCfg.MinAmount,
|
||||||
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
|
PaymentMaxAmount: updatedPaymentCfg.MaxAmount,
|
||||||
@ -1035,6 +1088,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
|||||||
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
|
PaymentMaxPendingOrders: updatedPaymentCfg.MaxPendingOrders,
|
||||||
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
|
PaymentEnabledTypes: updatedPaymentCfg.EnabledTypes,
|
||||||
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
|
PaymentBalanceDisabled: updatedPaymentCfg.BalanceDisabled,
|
||||||
|
PaymentBalanceRechargeMultiplier: updatedPaymentCfg.BalanceRechargeMultiplier,
|
||||||
|
PaymentRechargeFeeRate: updatedPaymentCfg.RechargeFeeRate,
|
||||||
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
|
PaymentLoadBalanceStrat: updatedPaymentCfg.LoadBalanceStrategy,
|
||||||
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
|
PaymentProductNamePrefix: updatedPaymentCfg.ProductNamePrefix,
|
||||||
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
|
PaymentProductNameSuffix: updatedPaymentCfg.ProductNameSuffix,
|
||||||
@ -1054,6 +1109,7 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
|
|||||||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
|
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
|
||||||
req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
|
req.PaymentOrderTimeoutMin != nil || req.PaymentMaxPendingOrders != nil ||
|
||||||
req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
|
req.PaymentEnabledTypes != nil || req.PaymentBalanceDisabled != nil ||
|
||||||
|
req.PaymentBalanceRechargeMultiplier != nil || req.PaymentRechargeFeeRate != nil ||
|
||||||
req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
|
req.PaymentLoadBalanceStrat != nil || req.PaymentProductNamePrefix != nil ||
|
||||||
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
|
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
|
||||||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
|
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
|
||||||
@ -1073,11 +1129,11 @@ func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.Sys
|
|||||||
|
|
||||||
subject, _ := middleware.GetAuthSubjectFromContext(c)
|
subject, _ := middleware.GetAuthSubjectFromContext(c)
|
||||||
role, _ := middleware.GetUserRoleFromContext(c)
|
role, _ := middleware.GetUserRoleFromContext(c)
|
||||||
log.Printf("AUDIT: settings updated at=%s user_id=%d role=%s changed=%v",
|
slog.Info("settings updated",
|
||||||
time.Now().UTC().Format(time.RFC3339),
|
"audit", true,
|
||||||
subject.UserID,
|
"user_id", subject.UserID,
|
||||||
role,
|
"role", role,
|
||||||
changed,
|
"changed", changed,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1092,6 +1148,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||||
changed = append(changed, "registration_email_suffix_whitelist")
|
changed = append(changed, "registration_email_suffix_whitelist")
|
||||||
}
|
}
|
||||||
|
if before.PromoCodeEnabled != after.PromoCodeEnabled {
|
||||||
|
changed = append(changed, "promo_code_enabled")
|
||||||
|
}
|
||||||
|
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
|
||||||
|
changed = append(changed, "invitation_code_enabled")
|
||||||
|
}
|
||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
@ -1302,6 +1364,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.CustomMenuItems != after.CustomMenuItems {
|
if before.CustomMenuItems != after.CustomMenuItems {
|
||||||
changed = append(changed, "custom_menu_items")
|
changed = append(changed, "custom_menu_items")
|
||||||
}
|
}
|
||||||
|
if before.CustomEndpoints != after.CustomEndpoints {
|
||||||
|
changed = append(changed, "custom_endpoints")
|
||||||
|
}
|
||||||
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
||||||
changed = append(changed, "enable_fingerprint_unification")
|
changed = append(changed, "enable_fingerprint_unification")
|
||||||
}
|
}
|
||||||
@ -1311,6 +1376,22 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.EnableCCHSigning != after.EnableCCHSigning {
|
if before.EnableCCHSigning != after.EnableCCHSigning {
|
||||||
changed = append(changed, "enable_cch_signing")
|
changed = append(changed, "enable_cch_signing")
|
||||||
}
|
}
|
||||||
|
// Balance & quota notification
|
||||||
|
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
|
||||||
|
changed = append(changed, "balance_low_notify_enabled")
|
||||||
|
}
|
||||||
|
if before.BalanceLowNotifyThreshold != after.BalanceLowNotifyThreshold {
|
||||||
|
changed = append(changed, "balance_low_notify_threshold")
|
||||||
|
}
|
||||||
|
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
|
||||||
|
changed = append(changed, "balance_low_notify_recharge_url")
|
||||||
|
}
|
||||||
|
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
|
||||||
|
changed = append(changed, "account_quota_notify_enabled")
|
||||||
|
}
|
||||||
|
if !equalNotifyEmailEntries(before.AccountQuotaNotifyEmails, after.AccountQuotaNotifyEmails) {
|
||||||
|
changed = append(changed, "account_quota_notify_emails")
|
||||||
|
}
|
||||||
return changed
|
return changed
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1367,6 +1448,18 @@ func equalIntSlice(a, b []int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func equalNotifyEmailEntries(a, b []service.NotifyEmailEntry) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i].Email != b[i].Email || a[i].Verified != b[i].Verified || a[i].Disabled != b[i].Disabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// TestSMTPRequest 测试SMTP连接请求
|
// TestSMTPRequest 测试SMTP连接请求
|
||||||
type TestSMTPRequest struct {
|
type TestSMTPRequest struct {
|
||||||
SMTPHost string `json:"smtp_host"`
|
SMTPHost string `json:"smtp_host"`
|
||||||
@ -1847,3 +1940,80 @@ func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
|
|||||||
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetWebSearchEmulationConfig 获取 Web Search 模拟配置
|
||||||
|
// GET /api/v1/admin/settings/web-search-emulation
|
||||||
|
func (h *SettingHandler) GetWebSearchEmulationConfig(c *gin.Context) {
|
||||||
|
cfg, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateWebSearchEmulationConfig 更新 Web Search 模拟配置
|
||||||
|
// PUT /api/v1/admin/settings/web-search-emulation
|
||||||
|
func (h *SettingHandler) UpdateWebSearchEmulationConfig(c *gin.Context) {
|
||||||
|
var cfg service.WebSearchEmulationConfig
|
||||||
|
if err := c.ShouldBindJSON(&cfg); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.settingService.SaveWebSearchEmulationConfig(c.Request.Context(), &cfg); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-read (with sanitized api keys) to return current state
|
||||||
|
updated, err := h.settingService.GetWebSearchEmulationConfig(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, service.PopulateWebSearchUsage(c.Request.Context(), updated))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetWebSearchUsage 重置指定 provider 的配额用量
|
||||||
|
// POST /api/v1/admin/settings/web-search-emulation/reset-usage
|
||||||
|
func (h *SettingHandler) ResetWebSearchUsage(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
ProviderType string `json:"provider_type"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.ProviderType == "" {
|
||||||
|
response.BadRequest(c, "provider_type is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := service.ResetWebSearchUsage(c.Request.Context(), req.ProviderType); err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWebSearchEmulation 测试 Web Search 搜索
|
||||||
|
// POST /api/v1/admin/settings/web-search-emulation/test
|
||||||
|
func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(req.Query) == "" {
|
||||||
|
req.Query = "搜索今年世界大事件"
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := service.TestWebSearch(c.Request.Context(), req.Query)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, result)
|
||||||
|
}
|
||||||
|
|||||||
@ -13,16 +13,21 @@ func UserFromServiceShallow(u *service.User) *User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &User{
|
return &User{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
AllowedGroups: u.AllowedGroups,
|
AllowedGroups: u.AllowedGroups,
|
||||||
CreatedAt: u.CreatedAt,
|
CreatedAt: u.CreatedAt,
|
||||||
UpdatedAt: u.UpdatedAt,
|
UpdatedAt: u.UpdatedAt,
|
||||||
|
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
|
||||||
|
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||||
|
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||||
|
BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails),
|
||||||
|
TotalRecharged: u.TotalRecharged,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,6 +327,26 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
|||||||
out.QuotaWeeklyResetAt = &v
|
out.QuotaWeeklyResetAt = &v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 配额通知配置
|
||||||
|
if enabled := a.GetQuotaNotifyDailyEnabled(); enabled {
|
||||||
|
out.QuotaNotifyDailyEnabled = &enabled
|
||||||
|
}
|
||||||
|
if threshold := a.GetQuotaNotifyDailyThreshold(); threshold > 0 {
|
||||||
|
out.QuotaNotifyDailyThreshold = &threshold
|
||||||
|
}
|
||||||
|
if enabled := a.GetQuotaNotifyWeeklyEnabled(); enabled {
|
||||||
|
out.QuotaNotifyWeeklyEnabled = &enabled
|
||||||
|
}
|
||||||
|
if threshold := a.GetQuotaNotifyWeeklyThreshold(); threshold > 0 {
|
||||||
|
out.QuotaNotifyWeeklyThreshold = &threshold
|
||||||
|
}
|
||||||
|
if enabled := a.GetQuotaNotifyTotalEnabled(); enabled {
|
||||||
|
out.QuotaNotifyTotalEnabled = &enabled
|
||||||
|
}
|
||||||
|
if threshold := a.GetQuotaNotifyTotalThreshold(); threshold > 0 {
|
||||||
|
out.QuotaNotifyTotalThreshold = &threshold
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@ -603,6 +628,7 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog {
|
|||||||
ModelMappingChain: l.ModelMappingChain,
|
ModelMappingChain: l.ModelMappingChain,
|
||||||
BillingTier: l.BillingTier,
|
BillingTier: l.BillingTier,
|
||||||
AccountRateMultiplier: l.AccountRateMultiplier,
|
AccountRateMultiplier: l.AccountRateMultiplier,
|
||||||
|
AccountStatsCost: l.AccountStatsCost,
|
||||||
IPAddress: l.IPAddress,
|
IPAddress: l.IPAddress,
|
||||||
Account: AccountSummaryFromService(l.Account),
|
Account: AccountSummaryFromService(l.Account),
|
||||||
}
|
}
|
||||||
|
|||||||
43
backend/internal/handler/dto/notify_email_entry.go
Normal file
43
backend/internal/handler/dto/notify_email_entry.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
|
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
|
||||||
|
// All emails are user-managed; maximum 3 entries per user.
|
||||||
|
type NotifyEmailEntry struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Disabled bool `json:"disabled"`
|
||||||
|
Verified bool `json:"verified"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyEmailEntriesFromService converts service entries to DTO entries.
|
||||||
|
func NotifyEmailEntriesFromService(entries []service.NotifyEmailEntry) []NotifyEmailEntry {
|
||||||
|
if entries == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]NotifyEmailEntry, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
result[i] = NotifyEmailEntry{
|
||||||
|
Email: e.Email,
|
||||||
|
Disabled: e.Disabled,
|
||||||
|
Verified: e.Verified,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// NotifyEmailEntriesToService converts DTO entries to service entries.
|
||||||
|
func NotifyEmailEntriesToService(entries []NotifyEmailEntry) []service.NotifyEmailEntry {
|
||||||
|
if entries == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]service.NotifyEmailEntry, len(entries))
|
||||||
|
for i, e := range entries {
|
||||||
|
result[i] = service.NotifyEmailEntry{
|
||||||
|
Email: e.Email,
|
||||||
|
Disabled: e.Disabled,
|
||||||
|
Verified: e.Verified,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
@ -124,20 +124,25 @@ type SystemSettings struct {
|
|||||||
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
|
||||||
EnableCCHSigning bool `json:"enable_cch_signing"`
|
EnableCCHSigning bool `json:"enable_cch_signing"`
|
||||||
|
|
||||||
|
// Web Search Emulation
|
||||||
|
WebSearchEmulationEnabled bool `json:"web_search_emulation_enabled"`
|
||||||
|
|
||||||
// Payment configuration
|
// Payment configuration
|
||||||
PaymentEnabled bool `json:"payment_enabled"`
|
PaymentEnabled bool `json:"payment_enabled"`
|
||||||
PaymentMinAmount float64 `json:"payment_min_amount"`
|
PaymentMinAmount float64 `json:"payment_min_amount"`
|
||||||
PaymentMaxAmount float64 `json:"payment_max_amount"`
|
PaymentMaxAmount float64 `json:"payment_max_amount"`
|
||||||
PaymentDailyLimit float64 `json:"payment_daily_limit"`
|
PaymentDailyLimit float64 `json:"payment_daily_limit"`
|
||||||
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
|
PaymentOrderTimeoutMin int `json:"payment_order_timeout_minutes"`
|
||||||
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
|
PaymentMaxPendingOrders int `json:"payment_max_pending_orders"`
|
||||||
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
PaymentEnabledTypes []string `json:"payment_enabled_types"`
|
||||||
PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
|
PaymentBalanceDisabled bool `json:"payment_balance_disabled"`
|
||||||
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
|
PaymentBalanceRechargeMultiplier float64 `json:"payment_balance_recharge_multiplier"`
|
||||||
PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
|
PaymentRechargeFeeRate float64 `json:"payment_recharge_fee_rate"`
|
||||||
PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
|
PaymentLoadBalanceStrat string `json:"payment_load_balance_strategy"`
|
||||||
PaymentHelpImageURL string `json:"payment_help_image_url"`
|
PaymentProductNamePrefix string `json:"payment_product_name_prefix"`
|
||||||
PaymentHelpText string `json:"payment_help_text"`
|
PaymentProductNameSuffix string `json:"payment_product_name_suffix"`
|
||||||
|
PaymentHelpImageURL string `json:"payment_help_image_url"`
|
||||||
|
PaymentHelpText string `json:"payment_help_text"`
|
||||||
|
|
||||||
// Cancel rate limit
|
// Cancel rate limit
|
||||||
PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
|
PaymentCancelRateLimitEnabled bool `json:"payment_cancel_rate_limit_enabled"`
|
||||||
@ -145,6 +150,13 @@ type SystemSettings struct {
|
|||||||
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
|
PaymentCancelRateLimitWindow int `json:"payment_cancel_rate_limit_window"`
|
||||||
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
|
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
|
||||||
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
|
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
|
||||||
|
|
||||||
|
// Balance low notification
|
||||||
|
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||||
|
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||||
|
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||||
|
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||||
|
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultSubscriptionSetting struct {
|
type DefaultSubscriptionSetting struct {
|
||||||
@ -183,6 +195,10 @@ type PublicSettings struct {
|
|||||||
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
BackendModeEnabled bool `json:"backend_mode_enabled"`
|
||||||
PaymentEnabled bool `json:"payment_enabled"`
|
PaymentEnabled bool `json:"payment_enabled"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
|
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||||
|
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||||
|
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||||
|
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// OverloadCooldownSettings 529过载冷却配置 DTO
|
// OverloadCooldownSettings 529过载冷却配置 DTO
|
||||||
|
|||||||
@ -18,6 +18,13 @@ type User struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
|
||||||
|
// 余额不足通知
|
||||||
|
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
|
||||||
|
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
|
||||||
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
|
||||||
|
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"`
|
||||||
|
TotalRecharged float64 `json:"total_recharged"`
|
||||||
|
|
||||||
APIKeys []APIKey `json:"api_keys,omitempty"`
|
APIKeys []APIKey `json:"api_keys,omitempty"`
|
||||||
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
|
||||||
}
|
}
|
||||||
@ -218,6 +225,14 @@ type Account struct {
|
|||||||
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
QuotaDailyResetAt *string `json:"quota_daily_reset_at,omitempty"`
|
||||||
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
QuotaWeeklyResetAt *string `json:"quota_weekly_reset_at,omitempty"`
|
||||||
|
|
||||||
|
// 配额通知配置
|
||||||
|
QuotaNotifyDailyEnabled *bool `json:"quota_notify_daily_enabled,omitempty"`
|
||||||
|
QuotaNotifyDailyThreshold *float64 `json:"quota_notify_daily_threshold,omitempty"`
|
||||||
|
QuotaNotifyWeeklyEnabled *bool `json:"quota_notify_weekly_enabled,omitempty"`
|
||||||
|
QuotaNotifyWeeklyThreshold *float64 `json:"quota_notify_weekly_threshold,omitempty"`
|
||||||
|
QuotaNotifyTotalEnabled *bool `json:"quota_notify_total_enabled,omitempty"`
|
||||||
|
QuotaNotifyTotalThreshold *float64 `json:"quota_notify_total_threshold,omitempty"`
|
||||||
|
|
||||||
Proxy *Proxy `json:"proxy,omitempty"`
|
Proxy *Proxy `json:"proxy,omitempty"`
|
||||||
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
|
||||||
|
|
||||||
@ -412,6 +427,8 @@ type AdminUsageLog struct {
|
|||||||
|
|
||||||
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
// AccountRateMultiplier 账号计费倍率快照(nil 表示按 1.0 处理)
|
||||||
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
|
||||||
|
// AccountStatsCost 自定义定价规则计算的账号统计费用(nil 表示使用默认公式)
|
||||||
|
AccountStatsCost *float64 `json:"account_stats_cost,omitempty"`
|
||||||
|
|
||||||
// IPAddress 用户请求 IP(仅管理员可见)
|
// IPAddress 用户请求 IP(仅管理员可见)
|
||||||
IPAddress *string `json:"ip_address,omitempty"`
|
IPAddress *string `json:"ip_address,omitempty"`
|
||||||
|
|||||||
@ -248,6 +248,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置请求所属分组 ID(用于渠道级功能判断,如 WebSearch 模拟)
|
||||||
|
parsedReq.GroupID = apiKey.GroupID
|
||||||
|
|
||||||
// 计算粘性会话hash
|
// 计算粘性会话hash
|
||||||
parsedReq.SessionContext = &service.SessionContext{
|
parsedReq.SessionContext = &service.SessionContext{
|
||||||
ClientIP: ip.GetClientIP(c),
|
ClientIP: ip.GetClientIP(c),
|
||||||
@ -470,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
|
ParsedRequest: parsedReq,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
@ -518,7 +522,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
|
|
||||||
for {
|
for {
|
||||||
// 选择支持该模型的账号
|
// 选择支持该模型的账号
|
||||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, int64(0))
|
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if len(fs.FailedAccountIDs) == 0 {
|
if len(fs.FailedAccountIDs) == 0 {
|
||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||||
@ -672,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转发请求 - 根据账号平台分流
|
// 转发请求 - 根据账号平台分流
|
||||||
|
c.Set("parsed_request", parsedReq)
|
||||||
var result *service.ForwardResult
|
var result *service.ForwardResult
|
||||||
requestCtx := c.Request.Context()
|
requestCtx := c.Request.Context()
|
||||||
if fs.SwitchCount > 0 {
|
if fs.SwitchCount > 0 {
|
||||||
@ -810,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
|
ParsedRequest: parsedReq,
|
||||||
APIKey: currentAPIKey,
|
APIKey: currentAPIKey,
|
||||||
User: currentAPIKey.User,
|
User: currentAPIKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
|
|||||||
@ -168,6 +168,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
|
|||||||
nil, // tlsFPProfileService
|
nil, // tlsFPProfileService
|
||||||
nil, // channelService
|
nil, // channelService
|
||||||
nil, // resolver
|
nil, // resolver
|
||||||
|
nil, // balanceNotifyService
|
||||||
)
|
)
|
||||||
|
|
||||||
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
// RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。
|
||||||
|
|||||||
@ -126,26 +126,30 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Success(c, checkoutInfoResponse{
|
response.Success(c, checkoutInfoResponse{
|
||||||
Methods: limitsResp.Methods,
|
Methods: limitsResp.Methods,
|
||||||
GlobalMin: limitsResp.GlobalMin,
|
GlobalMin: limitsResp.GlobalMin,
|
||||||
GlobalMax: limitsResp.GlobalMax,
|
GlobalMax: limitsResp.GlobalMax,
|
||||||
Plans: planList,
|
Plans: planList,
|
||||||
BalanceDisabled: cfg.BalanceDisabled,
|
BalanceDisabled: cfg.BalanceDisabled,
|
||||||
HelpText: cfg.HelpText,
|
BalanceRechargeMultiplier: cfg.BalanceRechargeMultiplier,
|
||||||
HelpImageURL: cfg.HelpImageURL,
|
RechargeFeeRate: cfg.RechargeFeeRate,
|
||||||
StripePublishableKey: cfg.StripePublishableKey,
|
HelpText: cfg.HelpText,
|
||||||
|
HelpImageURL: cfg.HelpImageURL,
|
||||||
|
StripePublishableKey: cfg.StripePublishableKey,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type checkoutInfoResponse struct {
|
type checkoutInfoResponse struct {
|
||||||
Methods map[string]service.MethodLimits `json:"methods"`
|
Methods map[string]service.MethodLimits `json:"methods"`
|
||||||
GlobalMin float64 `json:"global_min"`
|
GlobalMin float64 `json:"global_min"`
|
||||||
GlobalMax float64 `json:"global_max"`
|
GlobalMax float64 `json:"global_max"`
|
||||||
Plans []checkoutPlan `json:"plans"`
|
Plans []checkoutPlan `json:"plans"`
|
||||||
BalanceDisabled bool `json:"balance_disabled"`
|
BalanceDisabled bool `json:"balance_disabled"`
|
||||||
HelpText string `json:"help_text"`
|
BalanceRechargeMultiplier float64 `json:"balance_recharge_multiplier"`
|
||||||
HelpImageURL string `json:"help_image_url"`
|
RechargeFeeRate float64 `json:"recharge_fee_rate"`
|
||||||
StripePublishableKey string `json:"stripe_publishable_key"`
|
HelpText string `json:"help_text"`
|
||||||
|
HelpImageURL string `json:"help_image_url"`
|
||||||
|
StripePublishableKey string `json:"stripe_publishable_key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type checkoutPlan struct {
|
type checkoutPlan struct {
|
||||||
@ -335,6 +339,16 @@ func (h *PaymentHandler) RequestRefund(c *gin.Context) {
|
|||||||
response.Success(c, gin.H{"message": "refund requested"})
|
response.Success(c, gin.H{"message": "refund requested"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetRefundEligibleProviders returns provider instance IDs that allow user refund.
|
||||||
|
func (h *PaymentHandler) GetRefundEligibleProviders(c *gin.Context) {
|
||||||
|
ids, err := h.configService.GetUserRefundEligibleInstanceIDs(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"provider_instance_ids": ids})
|
||||||
|
}
|
||||||
|
|
||||||
// VerifyOrderRequest is the request body for verifying a payment order.
|
// VerifyOrderRequest is the request body for verifying a payment order.
|
||||||
type VerifyOrderRequest struct {
|
type VerifyOrderRequest struct {
|
||||||
OutTradeNo string `json:"out_trade_no" binding:"required"`
|
OutTradeNo string `json:"out_trade_no" binding:"required"`
|
||||||
@ -371,6 +385,7 @@ type PublicOrderResult struct {
|
|||||||
Amount float64 `json:"amount"`
|
Amount float64 `json:"amount"`
|
||||||
PayAmount float64 `json:"pay_amount"`
|
PayAmount float64 `json:"pay_amount"`
|
||||||
PaymentType string `json:"payment_type"`
|
PaymentType string `json:"payment_type"`
|
||||||
|
OrderType string `json:"order_type"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -394,6 +409,7 @@ func (h *PaymentHandler) VerifyOrderPublic(c *gin.Context) {
|
|||||||
Amount: order.Amount,
|
Amount: order.Amount,
|
||||||
PayAmount: order.PayAmount,
|
PayAmount: order.PayAmount,
|
||||||
PaymentType: order.PaymentType,
|
PaymentType: order.PaymentType,
|
||||||
|
OrderType: order.OrderType,
|
||||||
Status: order.Status,
|
Status: order.Status,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -61,5 +61,9 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
|||||||
BackendModeEnabled: settings.BackendModeEnabled,
|
BackendModeEnabled: settings.BackendModeEnabled,
|
||||||
PaymentEnabled: settings.PaymentEnabled,
|
PaymentEnabled: settings.PaymentEnabled,
|
||||||
Version: h.version,
|
Version: h.version,
|
||||||
|
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
|
||||||
|
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
|
||||||
|
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
|
||||||
|
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,13 +11,17 @@ import (
|
|||||||
|
|
||||||
// UserHandler handles user-related requests
|
// UserHandler handles user-related requests
|
||||||
type UserHandler struct {
|
type UserHandler struct {
|
||||||
userService *service.UserService
|
userService *service.UserService
|
||||||
|
emailService *service.EmailService
|
||||||
|
emailCache service.EmailCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserHandler creates a new UserHandler
|
// NewUserHandler creates a new UserHandler
|
||||||
func NewUserHandler(userService *service.UserService) *UserHandler {
|
func NewUserHandler(userService *service.UserService, emailService *service.EmailService, emailCache service.EmailCache) *UserHandler {
|
||||||
return &UserHandler{
|
return &UserHandler{
|
||||||
userService: userService,
|
userService: userService,
|
||||||
|
emailService: emailService,
|
||||||
|
emailCache: emailCache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,7 +33,9 @@ type ChangePasswordRequest struct {
|
|||||||
|
|
||||||
// UpdateProfileRequest represents the update profile request payload
|
// UpdateProfileRequest represents the update profile request payload
|
||||||
type UpdateProfileRequest struct {
|
type UpdateProfileRequest struct {
|
||||||
Username *string `json:"username"`
|
Username *string `json:"username"`
|
||||||
|
BalanceNotifyEnabled *bool `json:"balance_notify_enabled"`
|
||||||
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfile handles getting user profile
|
// GetProfile handles getting user profile
|
||||||
@ -94,7 +100,9 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
svcReq := service.UpdateProfileRequest{
|
svcReq := service.UpdateProfileRequest{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
|
BalanceNotifyEnabled: req.BalanceNotifyEnabled,
|
||||||
|
BalanceNotifyThreshold: req.BalanceNotifyThreshold,
|
||||||
}
|
}
|
||||||
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
updatedUser, err := h.userService.UpdateProfile(c.Request.Context(), subject.UserID, svcReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -104,3 +112,141 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
|||||||
|
|
||||||
response.Success(c, dto.UserFromService(updatedUser))
|
response.Success(c, dto.UserFromService(updatedUser))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SendNotifyEmailCodeRequest represents the request to send notify email verification code
|
||||||
|
type SendNotifyEmailCodeRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendNotifyEmailCode sends verification code to extra notification email
|
||||||
|
// POST /api/v1/user/notify-email/send-code
|
||||||
|
func (h *UserHandler) SendNotifyEmailCode(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req SendNotifyEmailCodeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.SendNotifyEmailCode(c.Request.Context(), subject.UserID, req.Email, h.emailService, h.emailCache)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, gin.H{"message": "Verification code sent successfully"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyNotifyEmailRequest represents the request to verify and add notify email
|
||||||
|
type VerifyNotifyEmailRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Code string `json:"code" binding:"required,len=6"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyNotifyEmail verifies code and adds email to notification list
|
||||||
|
// POST /api/v1/user/notify-email/verify
|
||||||
|
func (h *UserHandler) VerifyNotifyEmail(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req VerifyNotifyEmailRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.VerifyAndAddNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Code, h.emailCache)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return updated user
|
||||||
|
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.UserFromService(updatedUser))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNotifyEmailRequest represents the request to remove a notify email
|
||||||
|
type RemoveNotifyEmailRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNotifyEmail removes email from notification list
|
||||||
|
// DELETE /api/v1/user/notify-email
|
||||||
|
func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RemoveNotifyEmailRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.RemoveNotifyEmail(c.Request.Context(), subject.UserID, req.Email)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return updated user
|
||||||
|
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.UserFromService(updatedUser))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
|
||||||
|
type ToggleNotifyEmailRequest struct {
|
||||||
|
Email string `json:"email" binding:"required,email"`
|
||||||
|
Disabled bool `json:"disabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToggleNotifyEmail toggles the disabled state of a notification email
|
||||||
|
// PUT /api/v1/user/notify-email/toggle
|
||||||
|
func (h *UserHandler) ToggleNotifyEmail(c *gin.Context) {
|
||||||
|
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||||
|
if !ok {
|
||||||
|
response.Unauthorized(c, "User not authenticated")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ToggleNotifyEmailRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := h.userService.ToggleNotifyEmail(c.Request.Context(), subject.UserID, req.Email, req.Disabled)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedUser, err := h.userService.GetByID(c.Request.Context(), subject.UserID)
|
||||||
|
if err != nil {
|
||||||
|
response.ErrorFrom(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Success(c, dto.UserFromService(updatedUser))
|
||||||
|
}
|
||||||
|
|||||||
@ -94,17 +94,21 @@ func (lb *DefaultLoadBalancer) SelectInstance(
|
|||||||
return lb.buildSelection(selected.inst)
|
return lb.buildSelection(selected.inst)
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryEnabledInstances returns enabled instances for providerKey that support paymentType.
|
// queryEnabledInstances returns enabled instances that support paymentType.
|
||||||
|
// When providerKey is non-empty, only instances with that provider key are considered.
|
||||||
|
// When providerKey is empty, instances across all providers are considered,
|
||||||
|
// enabling cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
|
||||||
func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
providerKey string,
|
providerKey string,
|
||||||
paymentType PaymentType,
|
paymentType PaymentType,
|
||||||
) ([]*dbent.PaymentProviderInstance, error) {
|
) ([]*dbent.PaymentProviderInstance, error) {
|
||||||
instances, err := lb.db.PaymentProviderInstance.Query().
|
query := lb.db.PaymentProviderInstance.Query().
|
||||||
Where(
|
Where(paymentproviderinstance.Enabled(true))
|
||||||
paymentproviderinstance.ProviderKey(providerKey),
|
if providerKey != "" {
|
||||||
paymentproviderinstance.Enabled(true),
|
query = query.Where(paymentproviderinstance.ProviderKey(providerKey))
|
||||||
).
|
}
|
||||||
|
instances, err := query.
|
||||||
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
|
Order(dbent.Asc(paymentproviderinstance.FieldSortOrder)).
|
||||||
All(ctx)
|
All(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -113,12 +117,18 @@ func (lb *DefaultLoadBalancer) queryEnabledInstances(
|
|||||||
|
|
||||||
var matched []*dbent.PaymentProviderInstance
|
var matched []*dbent.PaymentProviderInstance
|
||||||
for _, inst := range instances {
|
for _, inst := range instances {
|
||||||
if paymentType == providerKey || InstanceSupportsType(inst.SupportedTypes, paymentType) {
|
// Stripe: match by provider_key because supported_types lists sub-types (card,link,alipay,wxpay),
|
||||||
|
// not "stripe" itself. The checkout page aggregates all sub-types under "stripe".
|
||||||
|
if paymentType == TypeStripe {
|
||||||
|
if inst.ProviderKey == TypeStripe {
|
||||||
|
matched = append(matched, inst)
|
||||||
|
}
|
||||||
|
} else if InstanceSupportsType(inst.SupportedTypes, paymentType) {
|
||||||
matched = append(matched, inst)
|
matched = append(matched, inst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(matched) == 0 {
|
if len(matched) == 0 {
|
||||||
return nil, fmt.Errorf("no enabled instance for provider %s type %s", providerKey, paymentType)
|
return nil, fmt.Errorf("no enabled instance for payment type %s", paymentType)
|
||||||
}
|
}
|
||||||
return matched, nil
|
return matched, nil
|
||||||
}
|
}
|
||||||
@ -258,6 +268,7 @@ func (lb *DefaultLoadBalancer) buildSelection(selected *dbent.PaymentProviderIns
|
|||||||
|
|
||||||
return &InstanceSelection{
|
return &InstanceSelection{
|
||||||
InstanceID: fmt.Sprintf("%d", selected.ID),
|
InstanceID: fmt.Sprintf("%d", selected.ID),
|
||||||
|
ProviderKey: selected.ProviderKey,
|
||||||
Config: config,
|
Config: config,
|
||||||
SupportedTypes: selected.SupportedTypes,
|
SupportedTypes: selected.SupportedTypes,
|
||||||
PaymentMode: selected.PaymentMode,
|
PaymentMode: selected.PaymentMode,
|
||||||
|
|||||||
@ -242,7 +242,7 @@ func TestFilterByLimits(t *testing.T) {
|
|||||||
wantIDs: nil,
|
wantIDs: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty candidates returns empty",
|
name: "empty candidates returns empty",
|
||||||
candidates: nil,
|
candidates: nil,
|
||||||
paymentType: "alipay",
|
paymentType: "alipay",
|
||||||
orderAmount: 10,
|
orderAmount: 10,
|
||||||
|
|||||||
@ -76,7 +76,7 @@ func (a *Alipay) getClient() (*alipay.Client, error) {
|
|||||||
func (a *Alipay) Name() string { return "Alipay" }
|
func (a *Alipay) Name() string { return "Alipay" }
|
||||||
func (a *Alipay) ProviderKey() string { return payment.TypeAlipay }
|
func (a *Alipay) ProviderKey() string { return payment.TypeAlipay }
|
||||||
func (a *Alipay) SupportedTypes() []payment.PaymentType {
|
func (a *Alipay) SupportedTypes() []payment.PaymentType {
|
||||||
return []payment.PaymentType{payment.TypeAlipayDirect}
|
return []payment.PaymentType{payment.TypeAlipay}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePayment creates an Alipay payment page URL.
|
// CreatePayment creates an Alipay payment page URL.
|
||||||
|
|||||||
@ -98,9 +98,9 @@ func TestNewAlipay(t *testing.T) {
|
|||||||
errSubstr: "privateKey",
|
errSubstr: "privateKey",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nil config map returns error for appId",
|
name: "nil config map returns error for appId",
|
||||||
config: map[string]string{},
|
config: map[string]string{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
errSubstr: "appId",
|
errSubstr: "appId",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -72,7 +72,7 @@ func NewWxpay(instanceID string, config map[string]string) (*Wxpay, error) {
|
|||||||
func (w *Wxpay) Name() string { return "Wxpay" }
|
func (w *Wxpay) Name() string { return "Wxpay" }
|
||||||
func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay }
|
func (w *Wxpay) ProviderKey() string { return payment.TypeWxpay }
|
||||||
func (w *Wxpay) SupportedTypes() []payment.PaymentType {
|
func (w *Wxpay) SupportedTypes() []payment.PaymentType {
|
||||||
return []payment.PaymentType{payment.TypeWxpayDirect}
|
return []payment.PaymentType{payment.TypeWxpay}
|
||||||
}
|
}
|
||||||
|
|
||||||
func formatPEM(key, keyType string) string {
|
func formatPEM(key, keyType string) string {
|
||||||
|
|||||||
@ -148,6 +148,7 @@ type RefundResponse struct {
|
|||||||
// InstanceSelection holds the selected provider instance and its decrypted config.
|
// InstanceSelection holds the selected provider instance and its decrypted config.
|
||||||
type InstanceSelection struct {
|
type InstanceSelection struct {
|
||||||
InstanceID string
|
InstanceID string
|
||||||
|
ProviderKey string // Provider key of the selected instance (e.g. "alipay", "easypay")
|
||||||
Config map[string]string
|
Config map[string]string
|
||||||
SupportedTypes string // Comma-separated list of supported payment types from the instance
|
SupportedTypes string // Comma-separated list of supported payment types from the instance
|
||||||
PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
|
PaymentMode string // Payment display mode: "qrcode", "redirect", "popup"
|
||||||
|
|||||||
@ -18,6 +18,9 @@ const (
|
|||||||
BlockTypeFunction
|
BlockTypeFunction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UsageMapHook is a callback that can modify usage data before it's emitted in SSE events.
|
||||||
|
type UsageMapHook func(usageMap map[string]any)
|
||||||
|
|
||||||
// StreamingProcessor 流式响应处理器
|
// StreamingProcessor 流式响应处理器
|
||||||
type StreamingProcessor struct {
|
type StreamingProcessor struct {
|
||||||
blockType BlockType
|
blockType BlockType
|
||||||
@ -30,6 +33,7 @@ type StreamingProcessor struct {
|
|||||||
originalModel string
|
originalModel string
|
||||||
webSearchQueries []string
|
webSearchQueries []string
|
||||||
groundingChunks []GeminiGroundingChunk
|
groundingChunks []GeminiGroundingChunk
|
||||||
|
usageMapHook UsageMapHook
|
||||||
|
|
||||||
// 累计 usage
|
// 累计 usage
|
||||||
inputTokens int
|
inputTokens int
|
||||||
@ -46,6 +50,28 @@ func NewStreamingProcessor(originalModel string) *StreamingProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetUsageMapHook sets an optional hook that modifies usage maps before they are emitted.
|
||||||
|
func (p *StreamingProcessor) SetUsageMapHook(fn UsageMapHook) {
|
||||||
|
p.usageMapHook = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func usageToMap(u ClaudeUsage) map[string]any {
|
||||||
|
m := map[string]any{
|
||||||
|
"input_tokens": u.InputTokens,
|
||||||
|
"output_tokens": u.OutputTokens,
|
||||||
|
}
|
||||||
|
if u.CacheCreationInputTokens > 0 {
|
||||||
|
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
|
||||||
|
}
|
||||||
|
if u.CacheReadInputTokens > 0 {
|
||||||
|
m["cache_read_input_tokens"] = u.CacheReadInputTokens
|
||||||
|
}
|
||||||
|
if u.ImageOutputTokens > 0 {
|
||||||
|
m["image_output_tokens"] = u.ImageOutputTokens
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
// ProcessLine 处理 SSE 行,返回 Claude SSE 事件
|
||||||
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
func (p *StreamingProcessor) ProcessLine(line string) []byte {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
@ -172,6 +198,13 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
responseID = "msg_" + generateRandomID()
|
responseID = "msg_" + generateRandomID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
message := map[string]any{
|
message := map[string]any{
|
||||||
"id": responseID,
|
"id": responseID,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
@ -180,7 +213,7 @@ func (p *StreamingProcessor) emitMessageStart(v1Resp *V1InternalResponse) []byte
|
|||||||
"model": p.originalModel,
|
"model": p.originalModel,
|
||||||
"stop_reason": nil,
|
"stop_reason": nil,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
event := map[string]any{
|
event := map[string]any{
|
||||||
@ -496,13 +529,20 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte {
|
|||||||
ImageOutputTokens: p.imageOutputTokens,
|
ImageOutputTokens: p.imageOutputTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var usageValue any = usage
|
||||||
|
if p.usageMapHook != nil {
|
||||||
|
usageMap := usageToMap(usage)
|
||||||
|
p.usageMapHook(usageMap)
|
||||||
|
usageValue = usageMap
|
||||||
|
}
|
||||||
|
|
||||||
deltaEvent := map[string]any{
|
deltaEvent := map[string]any{
|
||||||
"type": "message_delta",
|
"type": "message_delta",
|
||||||
"delta": map[string]any{
|
"delta": map[string]any{
|
||||||
"stop_reason": stopReason,
|
"stop_reason": stopReason,
|
||||||
"stop_sequence": nil,
|
"stop_sequence": nil,
|
||||||
},
|
},
|
||||||
"usage": usage,
|
"usage": usageValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
_, _ = result.Write(p.formatSSE("message_delta", deltaEvent))
|
||||||
|
|||||||
@ -27,13 +27,14 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := &ResponsesRequest{
|
out := &ResponsesRequest{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Input: inputJSON,
|
Instructions: req.Instructions,
|
||||||
Temperature: req.Temperature,
|
Input: inputJSON,
|
||||||
TopP: req.TopP,
|
Temperature: req.Temperature,
|
||||||
Stream: true, // upstream always streams
|
TopP: req.TopP,
|
||||||
Include: []string{"reasoning.encrypted_content"},
|
Stream: true, // upstream always streams
|
||||||
ServiceTier: req.ServiceTier,
|
Include: []string{"reasoning.encrypted_content"},
|
||||||
|
ServiceTier: req.ServiceTier,
|
||||||
}
|
}
|
||||||
|
|
||||||
storeFalse := false
|
storeFalse := false
|
||||||
|
|||||||
@ -152,6 +152,7 @@ type AnthropicDelta struct {
|
|||||||
// ResponsesRequest is the request body for POST /v1/responses.
|
// ResponsesRequest is the request body for POST /v1/responses.
|
||||||
type ResponsesRequest struct {
|
type ResponsesRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
Instructions string `json:"instructions,omitempty"`
|
||||||
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
Input json.RawMessage `json:"input"` // string or []ResponsesInputItem
|
||||||
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
@ -337,6 +338,7 @@ type ResponsesStreamEvent struct {
|
|||||||
type ChatCompletionsRequest struct {
|
type ChatCompletionsRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []ChatMessage `json:"messages"`
|
Messages []ChatMessage `json:"messages"`
|
||||||
|
Instructions string `json:"instructions,omitempty"` // OpenAI Responses API compat
|
||||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||||
Temperature *float64 `json:"temperature,omitempty"`
|
Temperature *float64 `json:"temperature,omitempty"`
|
||||||
|
|||||||
@ -10,7 +10,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestInit_DualOutput(t *testing.T) {
|
func TestInit_DualOutput(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
// Use os.MkdirTemp instead of t.TempDir to avoid cleanup failures
|
||||||
|
// when lumberjack holds file handles on Windows.
|
||||||
|
tmpDir, err := os.MkdirTemp("", "logger-test-*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create temp dir: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = os.RemoveAll(tmpDir) })
|
||||||
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
|
logPath := filepath.Join(tmpDir, "logs", "sub2api.log")
|
||||||
|
|
||||||
origStdout := os.Stdout
|
origStdout := os.Stdout
|
||||||
@ -57,7 +63,9 @@ func TestInit_DualOutput(t *testing.T) {
|
|||||||
|
|
||||||
L().Info("dual-output-info")
|
L().Info("dual-output-info")
|
||||||
L().Warn("dual-output-warn")
|
L().Warn("dual-output-warn")
|
||||||
Sync()
|
|
||||||
|
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
|
||||||
|
// The log data is already in the pipe buffer; closing writers is sufficient.
|
||||||
|
|
||||||
_ = stdoutW.Close()
|
_ = stdoutW.Close()
|
||||||
_ = stderrW.Close()
|
_ = stderrW.Close()
|
||||||
@ -166,7 +174,9 @@ func TestInit_CallerShouldPointToCallsite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
L().Info("caller-check")
|
L().Info("caller-check")
|
||||||
Sync()
|
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
|
||||||
|
os.Stdout = origStdout
|
||||||
|
os.Stderr = origStderr
|
||||||
_ = stdoutW.Close()
|
_ = stdoutW.Close()
|
||||||
logBytes, _ := io.ReadAll(stdoutR)
|
logBytes, _ := io.ReadAll(stdoutR)
|
||||||
|
|
||||||
|
|||||||
@ -77,7 +77,7 @@ func TestStdLogBridgeRoutesLevels(t *testing.T) {
|
|||||||
log.Printf("service started")
|
log.Printf("service started")
|
||||||
log.Printf("Warning: queue full")
|
log.Printf("Warning: queue full")
|
||||||
log.Printf("Forward request failed: timeout")
|
log.Printf("Forward request failed: timeout")
|
||||||
Sync()
|
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
|
||||||
|
|
||||||
_ = stdoutW.Close()
|
_ = stdoutW.Close()
|
||||||
_ = stderrW.Close()
|
_ = stderrW.Close()
|
||||||
@ -139,7 +139,7 @@ func TestLegacyPrintfRoutesLevels(t *testing.T) {
|
|||||||
LegacyPrintf("service.test", "request started")
|
LegacyPrintf("service.test", "request started")
|
||||||
LegacyPrintf("service.test", "Warning: queue full")
|
LegacyPrintf("service.test", "Warning: queue full")
|
||||||
LegacyPrintf("service.test", "forward failed: timeout")
|
LegacyPrintf("service.test", "forward failed: timeout")
|
||||||
Sync()
|
// Skip Sync() — on Windows, fsync on pipes deadlocks (FlushFileBuffers).
|
||||||
|
|
||||||
_ = stdoutW.Close()
|
_ = stdoutW.Close()
|
||||||
_ = stderrW.Close()
|
_ = stderrW.Close()
|
||||||
|
|||||||
@ -56,8 +56,9 @@ type DashboardStats struct {
|
|||||||
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
TotalCacheCreationTokens int64 `json:"total_cache_creation_tokens"`
|
||||||
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
TotalCacheReadTokens int64 `json:"total_cache_read_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
TotalCost float64 `json:"total_cost"` // 累计标准计费
|
||||||
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
TotalActualCost float64 `json:"total_actual_cost"` // 累计实际扣除
|
||||||
|
TotalAccountCost float64 `json:"total_account_cost"` // 累计账号成本
|
||||||
|
|
||||||
// 今日 Token 使用统计
|
// 今日 Token 使用统计
|
||||||
TodayRequests int64 `json:"today_requests"`
|
TodayRequests int64 `json:"today_requests"`
|
||||||
@ -66,8 +67,9 @@ type DashboardStats struct {
|
|||||||
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
TodayCacheCreationTokens int64 `json:"today_cache_creation_tokens"`
|
||||||
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
TodayCacheReadTokens int64 `json:"today_cache_read_tokens"`
|
||||||
TodayTokens int64 `json:"today_tokens"`
|
TodayTokens int64 `json:"today_tokens"`
|
||||||
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
TodayCost float64 `json:"today_cost"` // 今日标准计费
|
||||||
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
TodayActualCost float64 `json:"today_actual_cost"` // 今日实际扣除
|
||||||
|
TodayAccountCost float64 `json:"today_account_cost"` // 今日账号成本
|
||||||
|
|
||||||
// 系统运行统计
|
// 系统运行统计
|
||||||
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
AverageDurationMs float64 `json:"average_duration_ms"` // 平均响应时间
|
||||||
@ -99,8 +101,9 @@ type ModelStat struct {
|
|||||||
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
CacheCreationTokens int64 `json:"cache_creation_tokens"`
|
||||||
CacheReadTokens int64 `json:"cache_read_tokens"`
|
CacheReadTokens int64 `json:"cache_read_tokens"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
AccountCost float64 `json:"account_cost"` // 账号成本
|
||||||
}
|
}
|
||||||
|
|
||||||
// EndpointStat represents usage statistics for a single request endpoint.
|
// EndpointStat represents usage statistics for a single request endpoint.
|
||||||
@ -125,8 +128,9 @@ type GroupStat struct {
|
|||||||
GroupName string `json:"group_name"`
|
GroupName string `json:"group_name"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
AccountCost float64 `json:"account_cost"` // 账号成本
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserUsageTrendPoint represents user usage trend data point
|
// UserUsageTrendPoint represents user usage trend data point
|
||||||
@ -164,8 +168,9 @@ type UserBreakdownItem struct {
|
|||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Requests int64 `json:"requests"`
|
Requests int64 `json:"requests"`
|
||||||
TotalTokens int64 `json:"total_tokens"`
|
TotalTokens int64 `json:"total_tokens"`
|
||||||
Cost float64 `json:"cost"` // 标准计费
|
Cost float64 `json:"cost"` // 标准计费
|
||||||
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
ActualCost float64 `json:"actual_cost"` // 实际扣除
|
||||||
|
AccountCost float64 `json:"account_cost"` // 账号成本
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
// UserBreakdownDimension specifies the dimension to filter for user breakdown.
|
||||||
|
|||||||
106
backend/internal/pkg/websearch/brave.go
Normal file
106
backend/internal/pkg/websearch/brave.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
braveSearchEndpoint = "https://api.search.brave.com/res/v1/web/search"
|
||||||
|
braveMaxCount = 20
|
||||||
|
braveProviderName = "brave"
|
||||||
|
)
|
||||||
|
|
||||||
|
// braveSearchURL is pre-parsed at init time; url.Parse cannot fail on a constant literal.
|
||||||
|
var braveSearchURL, _ = url.Parse(braveSearchEndpoint) //nolint:errcheck
|
||||||
|
|
||||||
|
// BraveProvider implements web search via the Brave Search API.
|
||||||
|
type BraveProvider struct {
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBraveProvider creates a Brave Search provider.
|
||||||
|
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||||
|
func NewBraveProvider(apiKey string, httpClient *http.Client) *BraveProvider {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &BraveProvider{apiKey: apiKey, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BraveProvider) Name() string { return braveProviderName }
|
||||||
|
|
||||||
|
func (b *BraveProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
count := req.MaxResults
|
||||||
|
if count <= 0 {
|
||||||
|
count = defaultMaxResults
|
||||||
|
}
|
||||||
|
if count > braveMaxCount {
|
||||||
|
count = braveMaxCount
|
||||||
|
}
|
||||||
|
|
||||||
|
u := *braveSearchURL // copy the pre-parsed URL
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("q", req.Query)
|
||||||
|
q.Set("count", strconv.Itoa(count))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: build request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("X-Subscription-Token", b.apiKey)
|
||||||
|
httpReq.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := b.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("brave: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw braveResponse
|
||||||
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("brave: decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]SearchResult, 0, len(raw.Web.Results))
|
||||||
|
for _, r := range raw.Web.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL,
|
||||||
|
Title: r.Title,
|
||||||
|
Snippet: r.Description,
|
||||||
|
PageAge: r.Age,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// braveResponse is the minimal structure of the Brave Search API response.
|
||||||
|
type braveResponse struct {
|
||||||
|
Web struct {
|
||||||
|
Results []braveResult `json:"results"`
|
||||||
|
} `json:"web"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type braveResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Age string `json:"age"`
|
||||||
|
}
|
||||||
119
backend/internal/pkg/websearch/brave_test.go
Normal file
119
backend/internal/pkg/websearch/brave_test.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBraveProvider_Name(t *testing.T) {
|
||||||
|
p := NewBraveProvider("key", nil)
|
||||||
|
require.Equal(t, "brave", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_Success(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "test-key", r.Header.Get("X-Subscription-Token"))
|
||||||
|
require.Equal(t, "application/json", r.Header.Get("Accept"))
|
||||||
|
require.Equal(t, "golang", r.URL.Query().Get("q"))
|
||||||
|
require.Equal(t, "3", r.URL.Query().Get("count"))
|
||||||
|
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{
|
||||||
|
{URL: "https://go.dev", Title: "Go", Description: "Go lang", Age: "1 day"},
|
||||||
|
{URL: "https://pkg.go.dev", Title: "Pkg", Description: "Packages"},
|
||||||
|
{URL: "https://tour.go.dev", Title: "Tour", Description: "A Tour of Go", Age: "3 days"},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("test-key", srv.Client())
|
||||||
|
// Override the endpoint for testing
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
resp, err := p.Search(context.Background(), SearchRequest{Query: "golang", MaxResults: 3})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Results, 3)
|
||||||
|
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||||
|
require.Equal(t, "Go lang", resp.Results[0].Snippet)
|
||||||
|
require.Equal(t, "1 day", resp.Results[0].PageAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_DefaultMaxResults(t *testing.T) {
|
||||||
|
var receivedCount string
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedCount = r.URL.Query().Get("count")
|
||||||
|
resp := braveResponse{}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, _ = p.Search(context.Background(), SearchRequest{Query: "test", MaxResults: 0})
|
||||||
|
require.Equal(t, "5", receivedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(429)
|
||||||
|
_, _ = w.Write([]byte("rate limited"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "brave: status 429")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_InvalidJSON(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("not json"))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
_, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "brave: decode response")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBraveProvider_Search_EmptyResults(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
p := NewBraveProvider("key", srv.Client())
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
resp, err := p.Search(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Empty(t, resp.Results)
|
||||||
|
}
|
||||||
14
backend/internal/pkg/websearch/helpers.go
Normal file
14
backend/internal/pkg/websearch/helpers.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxResponseSize = 1 << 20 // 1 MB
|
||||||
|
errorBodyTruncLen = 200
|
||||||
|
)
|
||||||
|
|
||||||
|
// truncateBody returns a truncated string of body for error messages.
|
||||||
|
func truncateBody(body []byte) string {
|
||||||
|
if len(body) <= errorBodyTruncLen {
|
||||||
|
return string(body)
|
||||||
|
}
|
||||||
|
return string(body[:errorBodyTruncLen]) + "...(truncated)"
|
||||||
|
}
|
||||||
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
25
backend/internal/pkg/websearch/helpers_test.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTruncateBody_Short(t *testing.T) {
|
||||||
|
body := []byte("short body")
|
||||||
|
require.Equal(t, "short body", truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateBody_Long(t *testing.T) {
|
||||||
|
body := []byte(strings.Repeat("x", 500))
|
||||||
|
result := truncateBody(body)
|
||||||
|
require.Len(t, result, errorBodyTruncLen+len("...(truncated)"))
|
||||||
|
require.True(t, strings.HasSuffix(result, "...(truncated)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateBody_ExactBoundary(t *testing.T) {
|
||||||
|
body := []byte(strings.Repeat("x", errorBodyTruncLen))
|
||||||
|
require.Equal(t, string(body), truncateBody(body))
|
||||||
|
}
|
||||||
528
backend/internal/pkg/websearch/manager.go
Normal file
528
backend/internal/pkg/websearch/manager.go
Normal file
@ -0,0 +1,528 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderConfig holds the configuration for a single search provider.
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Type string `json:"type"` // ProviderTypeBrave | ProviderTypeTavily
|
||||||
|
APIKey string `json:"api_key"` // secret
|
||||||
|
QuotaLimit int64 `json:"quota_limit"` // 0 = unlimited
|
||||||
|
SubscribedAt *int64 `json:"subscribed_at,omitempty"` // subscription start (unix seconds); quota resets monthly from this date
|
||||||
|
ProxyURL string `json:"-"` // resolved proxy URL (not persisted)
|
||||||
|
ProxyID int64 `json:"-"` // resolved proxy ID for unavailability tracking
|
||||||
|
ExpiresAt *int64 `json:"expires_at,omitempty"` // optional expiration (unix seconds)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager selects providers by quota-weighted load balancing and tracks quota via Redis.
|
||||||
|
type Manager struct {
|
||||||
|
configs []ProviderConfig
|
||||||
|
redis *redis.Client
|
||||||
|
|
||||||
|
clientMu sync.Mutex
|
||||||
|
clientCache map[string]*http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout constants for proxy and search operations.
|
||||||
|
const (
|
||||||
|
proxyDialTimeout = 3 * time.Second // proxy TCP connection timeout
|
||||||
|
proxyTLSTimeout = 3 * time.Second // TLS handshake timeout
|
||||||
|
searchDataTimeout = 60 * time.Second // response data transfer timeout
|
||||||
|
searchRequestTimeout = searchDataTimeout + proxyDialTimeout
|
||||||
|
|
||||||
|
quotaKeyPrefix = "websearch:quota:"
|
||||||
|
proxyUnavailableKey = "websearch:proxy_unavailable:%d"
|
||||||
|
proxyUnavailableTTL = 5 * time.Minute
|
||||||
|
quotaTTLBuffer = 24 * time.Hour
|
||||||
|
defaultQuotaTTL = 31*24*time.Hour + quotaTTLBuffer // fallback when no subscription date
|
||||||
|
maxCachedClients = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrProxyUnavailable indicates the search failed due to a proxy connectivity issue.
|
||||||
|
// Callers may use this to trigger account switching instead of direct fallback.
|
||||||
|
var ErrProxyUnavailable = errors.New("websearch: proxy unavailable")
|
||||||
|
|
||||||
|
// quotaIncrScript atomically increments the counter and sets TTL on first creation.
|
||||||
|
var quotaIncrScript = redis.NewScript(`
|
||||||
|
local val = redis.call('INCR', KEYS[1])
|
||||||
|
if val == 1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||||
|
else
|
||||||
|
local ttl = redis.call('TTL', KEYS[1])
|
||||||
|
if ttl == -1 then
|
||||||
|
redis.call('EXPIRE', KEYS[1], ARGV[1])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return val
|
||||||
|
`)
|
||||||
|
|
||||||
|
// NewManager creates a Manager with the given provider configs and Redis client.
|
||||||
|
// Provider order is preserved as-is; selectByQuotaWeight handles load balancing.
|
||||||
|
func NewManager(configs []ProviderConfig, redisClient *redis.Client) *Manager {
|
||||||
|
copied := make([]ProviderConfig, len(configs))
|
||||||
|
copy(copied, configs)
|
||||||
|
return &Manager{
|
||||||
|
configs: copied,
|
||||||
|
redis: redisClient,
|
||||||
|
clientCache: make(map[string]*http.Client),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchWithBestProvider selects a provider using quota-weighted load balancing,
|
||||||
|
// reserves quota, executes the search, and rolls back quota on failure.
|
||||||
|
// If the search fails due to a proxy error, the proxy is marked unavailable for 5 minutes.
|
||||||
|
func (m *Manager) SearchWithBestProvider(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
|
||||||
|
if strings.TrimSpace(req.Query) == "" {
|
||||||
|
return nil, "", fmt.Errorf("websearch: empty search query")
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates := m.filterAvailableProviders(ctx, req.ProxyURL)
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted, expired, or proxy unavailable)")
|
||||||
|
}
|
||||||
|
|
||||||
|
selected := m.selectByQuotaWeight(ctx, candidates)
|
||||||
|
|
||||||
|
for _, cfg := range selected {
|
||||||
|
allowed, incremented := m.tryReserveQuota(ctx, cfg)
|
||||||
|
if !allowed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp, err := m.executeSearch(ctx, cfg, req)
|
||||||
|
if err != nil {
|
||||||
|
if incremented {
|
||||||
|
m.rollbackQuota(ctx, cfg)
|
||||||
|
}
|
||||||
|
if isProxyError(err) {
|
||||||
|
m.markProxyUnavailable(ctx, cfg, req.ProxyURL)
|
||||||
|
if req.ProxyURL != "" {
|
||||||
|
// Account-level proxy is shared by all providers — no point
|
||||||
|
// trying others with the same broken proxy; signal account switch.
|
||||||
|
slog.Warn("websearch: account proxy error, aborting failover",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
return nil, "", fmt.Errorf("%w: %s", ErrProxyUnavailable, err.Error())
|
||||||
|
}
|
||||||
|
// Provider-specific proxy failed — try the next provider which
|
||||||
|
// may use a different (or no) proxy.
|
||||||
|
slog.Warn("websearch: provider proxy error, trying next provider",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
slog.Warn("websearch: provider search failed",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, cfg.Type, nil
|
||||||
|
}
|
||||||
|
return nil, "", fmt.Errorf("websearch: no available provider (all exhausted or failed)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAvailableProviders returns providers that have API keys, are not expired,
|
||||||
|
// and whose proxies are not marked unavailable.
|
||||||
|
func (m *Manager) filterAvailableProviders(ctx context.Context, accountProxyURL string) []ProviderConfig {
|
||||||
|
var out []ProviderConfig
|
||||||
|
for _, cfg := range m.configs {
|
||||||
|
if !m.isProviderAvailable(cfg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
proxyID := resolveProxyID(cfg, accountProxyURL)
|
||||||
|
if proxyID > 0 && !m.isProxyAvailable(ctx, proxyID) {
|
||||||
|
slog.Debug("websearch: proxy marked unavailable, skipping",
|
||||||
|
"provider", cfg.Type, "proxy_id", proxyID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, cfg)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// weighted is a provider candidate with computed quota weight.
|
||||||
|
type weighted struct {
|
||||||
|
cfg ProviderConfig
|
||||||
|
weight int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectByQuotaWeight orders candidates by remaining quota weight.
|
||||||
|
// Providers with quota_limit=0 (no limit set) get weight 0 and are placed last.
|
||||||
|
// Among providers with quota, higher remaining quota = higher priority.
|
||||||
|
func (m *Manager) selectByQuotaWeight(ctx context.Context, candidates []ProviderConfig) []ProviderConfig {
|
||||||
|
items := m.computeWeights(ctx, candidates)
|
||||||
|
withQuota, withoutQuota := partitionByQuota(items)
|
||||||
|
sortByStableRandomWeight(withQuota)
|
||||||
|
return mergeWeightedResults(withQuota, withoutQuota, len(candidates))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) computeWeights(ctx context.Context, candidates []ProviderConfig) []weighted {
|
||||||
|
items := make([]weighted, 0, len(candidates))
|
||||||
|
for _, cfg := range candidates {
|
||||||
|
w := int64(0)
|
||||||
|
if cfg.QuotaLimit > 0 {
|
||||||
|
used, _ := m.GetUsage(ctx, cfg.Type)
|
||||||
|
if remaining := cfg.QuotaLimit - used; remaining > 0 {
|
||||||
|
w = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
items = append(items, weighted{cfg: cfg, weight: w})
|
||||||
|
}
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
|
||||||
|
func partitionByQuota(items []weighted) (withQuota, withoutQuota []weighted) {
|
||||||
|
for _, item := range items {
|
||||||
|
if item.weight > 0 {
|
||||||
|
withQuota = append(withQuota, item)
|
||||||
|
} else {
|
||||||
|
withoutQuota = append(withoutQuota, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortByStableRandomWeight assigns a fixed random factor to each item before sorting,
|
||||||
|
// ensuring deterministic sort behavior (transitivity) within a single call.
|
||||||
|
func sortByStableRandomWeight(items []weighted) {
|
||||||
|
if len(items) <= 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
type entry struct {
|
||||||
|
item weighted
|
||||||
|
factor float64
|
||||||
|
}
|
||||||
|
entries := make([]entry, len(items))
|
||||||
|
for i, item := range items {
|
||||||
|
entries[i] = entry{item: item, factor: float64(item.weight) * (0.5 + rand.Float64())}
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].factor > entries[j].factor
|
||||||
|
})
|
||||||
|
for i, e := range entries {
|
||||||
|
items[i] = e.item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeWeightedResults(withQuota, withoutQuota []weighted, capacity int) []ProviderConfig {
|
||||||
|
result := make([]ProviderConfig, 0, capacity)
|
||||||
|
for _, item := range withQuota {
|
||||||
|
result = append(result, item.cfg)
|
||||||
|
}
|
||||||
|
for _, item := range withoutQuota {
|
||||||
|
result = append(result, item.cfg)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) isProviderAvailable(cfg ProviderConfig) bool {
|
||||||
|
if cfg.APIKey == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if cfg.ExpiresAt != nil && time.Now().Unix() > *cfg.ExpiresAt {
|
||||||
|
slog.Info("websearch: provider expired, skipping",
|
||||||
|
"provider", cfg.Type, "expires_at", *cfg.ExpiresAt)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Proxy availability tracking ---
|
||||||
|
|
||||||
|
// markProxyUnavailable marks the effective proxy as unavailable for proxyUnavailableTTL.
|
||||||
|
func (m *Manager) markProxyUnavailable(ctx context.Context, cfg ProviderConfig, accountProxyURL string) {
|
||||||
|
proxyID := resolveProxyID(cfg, accountProxyURL)
|
||||||
|
if proxyID <= 0 || m.redis == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
|
||||||
|
if err := m.redis.Set(ctx, key, "1", proxyUnavailableTTL).Err(); err != nil {
|
||||||
|
slog.Warn("websearch: failed to mark proxy unavailable",
|
||||||
|
"proxy_id", proxyID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isProxyAvailable checks whether a proxy is currently marked as unavailable.
|
||||||
|
func (m *Manager) isProxyAvailable(ctx context.Context, proxyID int64) bool {
|
||||||
|
if m.redis == nil || proxyID <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf(proxyUnavailableKey, proxyID)
|
||||||
|
val, err := m.redis.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return true // Redis error → assume available
|
||||||
|
}
|
||||||
|
return val == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveProxyID determines the effective proxy ID for a provider+account combination.
|
||||||
|
func resolveProxyID(cfg ProviderConfig, accountProxyURL string) int64 {
|
||||||
|
if accountProxyURL != "" {
|
||||||
|
return 0 // account proxy has no ID in provider config
|
||||||
|
}
|
||||||
|
return cfg.ProxyID
|
||||||
|
}
|
||||||
|
|
||||||
|
// isProxyError checks whether the error is likely caused by proxy or network connectivity
|
||||||
|
// (as opposed to an API-level error from the search provider).
|
||||||
|
func isProxyError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Network-level errors (timeout, connection refused, DNS failure)
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// TLS handshake failures (often caused by proxy intercepting/blocking)
|
||||||
|
var tlsErr *tls.RecordHeaderError
|
||||||
|
if errors.As(err, &tlsErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// String-based detection for wrapped errors
|
||||||
|
msg := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(msg, "proxy") ||
|
||||||
|
strings.Contains(msg, "socks") ||
|
||||||
|
strings.Contains(msg, "connection refused") ||
|
||||||
|
strings.Contains(msg, "no such host") ||
|
||||||
|
strings.Contains(msg, "i/o timeout") ||
|
||||||
|
strings.Contains(msg, "tls handshake") ||
|
||||||
|
strings.Contains(msg, "certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Quota management ---
|
||||||
|
|
||||||
|
func (m *Manager) tryReserveQuota(ctx context.Context, cfg ProviderConfig) (bool, bool) {
|
||||||
|
if cfg.QuotaLimit <= 0 {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
if m.redis == nil {
|
||||||
|
slog.Warn("websearch: Redis unavailable, quota check skipped", "provider", cfg.Type)
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(cfg.Type)
|
||||||
|
ttlSec := int(quotaTTLFromSubscription(cfg.SubscribedAt).Seconds())
|
||||||
|
newVal, err := quotaIncrScript.Run(ctx, m.redis, []string{key}, ttlSec).Int64()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("websearch: quota Lua INCR failed, allowing request",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
if newVal > cfg.QuotaLimit {
|
||||||
|
if decrErr := m.redis.Decr(ctx, key).Err(); decrErr != nil {
|
||||||
|
slog.Warn("websearch: quota over-limit DECR failed",
|
||||||
|
"provider", cfg.Type, "error", decrErr)
|
||||||
|
}
|
||||||
|
slog.Info("websearch: provider quota exhausted",
|
||||||
|
"provider", cfg.Type, "used", newVal, "limit", cfg.QuotaLimit)
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) rollbackQuota(ctx context.Context, cfg ProviderConfig) {
|
||||||
|
if cfg.QuotaLimit <= 0 || m.redis == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(cfg.Type)
|
||||||
|
if err := m.redis.Decr(ctx, key).Err(); err != nil {
|
||||||
|
slog.Warn("websearch: quota rollback DECR failed",
|
||||||
|
"provider", cfg.Type, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Search execution ---
|
||||||
|
|
||||||
|
// TestSearch executes a search using the first available provider without reserving quota.
|
||||||
|
// Intended for admin test functionality only.
|
||||||
|
func (m *Manager) TestSearch(ctx context.Context, req SearchRequest) (*SearchResponse, string, error) {
|
||||||
|
if strings.TrimSpace(req.Query) == "" {
|
||||||
|
return nil, "", fmt.Errorf("websearch: empty search query")
|
||||||
|
}
|
||||||
|
for _, cfg := range m.configs {
|
||||||
|
if !m.isProviderAvailable(cfg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resp, err := m.executeSearch(ctx, cfg, req)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return resp, cfg.Type, nil
|
||||||
|
}
|
||||||
|
return nil, "", fmt.Errorf("websearch: no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeSearch(ctx context.Context, cfg ProviderConfig, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
proxyURL := cfg.ProxyURL
|
||||||
|
if req.ProxyURL != "" {
|
||||||
|
proxyURL = req.ProxyURL
|
||||||
|
}
|
||||||
|
client, err := m.getOrCreateHTTPClient(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("websearch: %w", err)
|
||||||
|
}
|
||||||
|
provider := m.buildProvider(cfg, client)
|
||||||
|
return provider.Search(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- HTTP client cache ---
|
||||||
|
|
||||||
|
func (m *Manager) getOrCreateHTTPClient(proxyURL string) (*http.Client, error) {
|
||||||
|
m.clientMu.Lock()
|
||||||
|
defer m.clientMu.Unlock()
|
||||||
|
|
||||||
|
if c, ok := m.clientCache[proxyURL]; ok {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
if len(m.clientCache) >= maxCachedClients {
|
||||||
|
m.clientCache = make(map[string]*http.Client)
|
||||||
|
}
|
||||||
|
c, err := newHTTPClient(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.clientCache[proxyURL] = c
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHTTPClient creates an HTTP client with proper timeout settings.
|
||||||
|
// Uses proxyutil.ConfigureTransportProxy for unified proxy protocol support
|
||||||
|
// (HTTP/HTTPS/SOCKS5/SOCKS5H).
|
||||||
|
// Returns error if proxyURL is invalid — never falls back to direct connection.
|
||||||
|
func newHTTPClient(proxyURL string) (*http.Client, error) {
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||||
|
DialContext: (&net.Dialer{Timeout: proxyDialTimeout}).DialContext,
|
||||||
|
TLSHandshakeTimeout: proxyTLSTimeout,
|
||||||
|
ResponseHeaderTimeout: searchDataTimeout,
|
||||||
|
}
|
||||||
|
if proxyURL != "" {
|
||||||
|
parsed, err := url.Parse(proxyURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid proxy URL %q: %w", proxyURL, err)
|
||||||
|
}
|
||||||
|
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("configure proxy: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport, Timeout: searchRequestTimeout}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsage returns the current usage count for the given provider.
|
||||||
|
func (m *Manager) GetUsage(ctx context.Context, providerType string) (int64, error) {
|
||||||
|
if m.redis == nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(providerType)
|
||||||
|
val, err := m.redis.Get(ctx, key).Int64()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllUsage returns usage for every configured provider.
|
||||||
|
func (m *Manager) GetAllUsage(ctx context.Context) map[string]int64 {
|
||||||
|
result := make(map[string]int64, len(m.configs))
|
||||||
|
for _, cfg := range m.configs {
|
||||||
|
used, _ := m.GetUsage(ctx, cfg.Type)
|
||||||
|
result[cfg.Type] = used
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetUsage deletes the Redis quota key for the given provider, resetting usage to 0.
|
||||||
|
func (m *Manager) ResetUsage(ctx context.Context, providerType string) error {
|
||||||
|
if m.redis == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
key := quotaRedisKey(providerType)
|
||||||
|
return m.redis.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Provider factory ---
|
||||||
|
|
||||||
|
func (m *Manager) buildProvider(cfg ProviderConfig, client *http.Client) Provider {
|
||||||
|
switch cfg.Type {
|
||||||
|
case braveProviderName:
|
||||||
|
return NewBraveProvider(cfg.APIKey, client)
|
||||||
|
case tavilyProviderName:
|
||||||
|
return NewTavilyProvider(cfg.APIKey, client)
|
||||||
|
default:
|
||||||
|
slog.Warn("websearch: unknown provider type, falling back to brave",
|
||||||
|
"type", cfg.Type)
|
||||||
|
return NewBraveProvider(cfg.APIKey, client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Redis key helpers ---
|
||||||
|
|
||||||
|
func quotaRedisKey(providerType string) string {
|
||||||
|
return quotaKeyPrefix + providerType
|
||||||
|
}
|
||||||
|
|
||||||
|
// quotaTTLFromSubscription calculates the TTL for the quota counter based on
|
||||||
|
// the provider's subscription start date. Quota resets monthly from that date.
|
||||||
|
// When the Redis key expires naturally, the next INCR creates a fresh counter (lazy refresh).
|
||||||
|
func quotaTTLFromSubscription(subscribedAt *int64) time.Duration {
|
||||||
|
if subscribedAt == nil || *subscribedAt == 0 {
|
||||||
|
return defaultQuotaTTL
|
||||||
|
}
|
||||||
|
next := nextMonthlyReset(time.Unix(*subscribedAt, 0).UTC())
|
||||||
|
ttl := time.Until(next) + quotaTTLBuffer
|
||||||
|
if ttl <= quotaTTLBuffer {
|
||||||
|
// Already past the reset — next cycle
|
||||||
|
ttl = defaultQuotaTTL
|
||||||
|
}
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextMonthlyReset returns the next monthly reset time based on the subscription start date.
|
||||||
|
// E.g., subscribed on Jan 15 → resets on Feb 15, Mar 15, etc.
|
||||||
|
// Handles day-of-month overflow: Jan 31 → Feb 28 (not Mar 3).
|
||||||
|
func nextMonthlyReset(subscribedAt time.Time) time.Time {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if subscribedAt.IsZero() {
|
||||||
|
return now.AddDate(0, 1, 0)
|
||||||
|
}
|
||||||
|
months := (now.Year()-subscribedAt.Year())*12 + int(now.Month()-subscribedAt.Month())
|
||||||
|
if months < 0 {
|
||||||
|
months = 0
|
||||||
|
}
|
||||||
|
candidate := addMonthsClamped(subscribedAt, months)
|
||||||
|
if candidate.After(now) {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
return addMonthsClamped(subscribedAt, months+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMonthsClamped adds N months to a date, clamping the day to the last day of the target month.
|
||||||
|
// E.g., Jan 31 + 1 month = Feb 28 (not Mar 3).
|
||||||
|
func addMonthsClamped(t time.Time, months int) time.Time {
|
||||||
|
y, m, d := t.Date()
|
||||||
|
targetMonth := time.Month(int(m) + months)
|
||||||
|
targetYear := y + int(targetMonth-1)/12
|
||||||
|
targetMonth = (targetMonth-1)%12 + 1
|
||||||
|
// Last day of the target month
|
||||||
|
lastDay := time.Date(targetYear, targetMonth+1, 0, 0, 0, 0, 0, time.UTC).Day()
|
||||||
|
if d > lastDay {
|
||||||
|
d = lastDay
|
||||||
|
}
|
||||||
|
return time.Date(targetYear, targetMonth, d, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
323
backend/internal/pkg/websearch/manager_test.go
Normal file
323
backend/internal/pkg/websearch/manager_test.go
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewManager_PreservesOrder(t *testing.T) {
|
||||||
|
configs := []ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k3"},
|
||||||
|
{Type: "tavily", APIKey: "k1"},
|
||||||
|
}
|
||||||
|
m := NewManager(configs, nil)
|
||||||
|
require.Equal(t, "brave", m.configs[0].Type)
|
||||||
|
require.Equal(t, "tavily", m.configs[1].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_EmptyQuery(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: "k"}}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: ""})
|
||||||
|
require.ErrorContains(t, err, "empty search query")
|
||||||
|
|
||||||
|
_, _, err = m.SearchWithBestProvider(context.Background(), SearchRequest{Query: " "})
|
||||||
|
require.ErrorContains(t, err, "empty search query")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_SkipEmptyAPIKey(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{{Type: "brave", APIKey: ""}}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_SkipExpired(t *testing.T) {
|
||||||
|
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k", ExpiresAt: &past},
|
||||||
|
}, nil)
|
||||||
|
_, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.ErrorContains(t, err, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_UsesFirstAvailable(t *testing.T) {
|
||||||
|
srvBrave := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{{URL: "https://brave.com", Title: "Brave", Description: "from brave"}}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srvBrave.Close()
|
||||||
|
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srvBrave.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k1"},
|
||||||
|
{Type: "tavily", APIKey: "k2"},
|
||||||
|
}, nil)
|
||||||
|
m.clientCache[srvBrave.URL] = srvBrave.Client()
|
||||||
|
m.clientCache[""] = srvBrave.Client()
|
||||||
|
|
||||||
|
resp, providerName, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "brave", providerName)
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
require.Equal(t, "from brave", resp.Results[0].Snippet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SearchWithBestProvider_NilRedis(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
resp := braveResponse{}
|
||||||
|
resp.Web.Results = []braveResult{{URL: "https://test.com", Title: "Test", Description: "result"}}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
origURL := *braveSearchURL
|
||||||
|
u, _ := http.NewRequest("GET", srv.URL, nil)
|
||||||
|
*braveSearchURL = *u.URL
|
||||||
|
defer func() { *braveSearchURL = origURL }()
|
||||||
|
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k", QuotaLimit: 100},
|
||||||
|
}, nil)
|
||||||
|
m.clientCache[""] = srv.Client()
|
||||||
|
|
||||||
|
resp, _, err := m.SearchWithBestProvider(context.Background(), SearchRequest{Query: "test"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetUsage_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
used, err := m.GetUsage(context.Background(), "brave")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, int64(0), used)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_GetAllUsage_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager([]ProviderConfig{
|
||||||
|
{Type: "brave"},
|
||||||
|
}, nil)
|
||||||
|
usage := m.GetAllUsage(context.Background())
|
||||||
|
require.Equal(t, int64(0), usage["brave"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Quota TTL from subscription ---
|
||||||
|
|
||||||
|
func TestQuotaTTLFromSubscription_NilSubscription(t *testing.T) {
|
||||||
|
ttl := quotaTTLFromSubscription(nil)
|
||||||
|
require.Equal(t, defaultQuotaTTL, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuotaTTLFromSubscription_ZeroSubscription(t *testing.T) {
|
||||||
|
zero := int64(0)
|
||||||
|
ttl := quotaTTLFromSubscription(&zero)
|
||||||
|
require.Equal(t, defaultQuotaTTL, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuotaTTLFromSubscription_ValidSubscription(t *testing.T) {
|
||||||
|
// Subscribed 10 days ago — next reset in ~20 days
|
||||||
|
sub := time.Now().Add(-10 * 24 * time.Hour).Unix()
|
||||||
|
ttl := quotaTTLFromSubscription(&sub)
|
||||||
|
require.Greater(t, ttl, 15*24*time.Hour) // at least 15 days
|
||||||
|
require.Less(t, ttl, 25*24*time.Hour+quotaTTLBuffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextMonthlyReset_SubscribedRecentPast(t *testing.T) {
|
||||||
|
// Subscribed on the 10th of this month (always valid day)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
sub := time.Date(now.Year(), now.Month(), 10, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := nextMonthlyReset(sub)
|
||||||
|
require.True(t, next.After(now) || next.Equal(now), "next reset should be in the future or now")
|
||||||
|
require.True(t, next.Before(now.AddDate(0, 1, 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextMonthlyReset_SubscribedLongAgo(t *testing.T) {
|
||||||
|
// Subscribed 6 months ago on the 1st
|
||||||
|
sub := time.Now().UTC().AddDate(0, -6, 0)
|
||||||
|
sub = time.Date(sub.Year(), sub.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := nextMonthlyReset(sub)
|
||||||
|
require.True(t, next.After(time.Now().UTC()))
|
||||||
|
// Should be within the next 31 days
|
||||||
|
require.True(t, next.Before(time.Now().UTC().AddDate(0, 1, 1)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextMonthlyReset_FutureSubscription(t *testing.T) {
|
||||||
|
sub := time.Now().UTC().AddDate(0, 0, 5)
|
||||||
|
next := nextMonthlyReset(sub)
|
||||||
|
require.True(t, next.After(time.Now().UTC()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMonthsClamped_Jan31ToFeb(t *testing.T) {
|
||||||
|
sub := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := addMonthsClamped(sub, 1)
|
||||||
|
require.Equal(t, time.Month(2), next.Month())
|
||||||
|
require.Equal(t, 28, next.Day()) // Feb 28 (2026 is not a leap year)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMonthsClamped_Jan31ToFebLeapYear(t *testing.T) {
|
||||||
|
sub := time.Date(2028, 1, 31, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := addMonthsClamped(sub, 1)
|
||||||
|
require.Equal(t, time.Month(2), next.Month())
|
||||||
|
require.Equal(t, 29, next.Day()) // Feb 29 (2028 is a leap year)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMonthsClamped_Mar31ToApr(t *testing.T) {
|
||||||
|
sub := time.Date(2026, 3, 31, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := addMonthsClamped(sub, 1)
|
||||||
|
require.Equal(t, time.Month(4), next.Month())
|
||||||
|
require.Equal(t, 30, next.Day()) // Apr has 30 days
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddMonthsClamped_NormalDay(t *testing.T) {
|
||||||
|
sub := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||||
|
next := addMonthsClamped(sub, 1)
|
||||||
|
require.Equal(t, time.Month(2), next.Month())
|
||||||
|
require.Equal(t, 15, next.Day()) // no clamping needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Redis key ---
|
||||||
|
|
||||||
|
func TestQuotaRedisKey_Format(t *testing.T) {
|
||||||
|
key := quotaRedisKey("brave")
|
||||||
|
require.Equal(t, "websearch:quota:brave", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- isProviderAvailable ---
|
||||||
|
|
||||||
|
func TestIsProviderAvailable_EmptyAPIKey(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: ""}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProviderAvailable_Expired(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
past := time.Now().Add(-1 * time.Hour).Unix()
|
||||||
|
require.False(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &past}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProviderAvailable_Valid(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
future := time.Now().Add(1 * time.Hour).Unix()
|
||||||
|
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k", ExpiresAt: &future}))
|
||||||
|
require.True(t, m.isProviderAvailable(ProviderConfig{APIKey: "k"})) // no expiry
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- resolveProxyID ---
|
||||||
|
|
||||||
|
func TestResolveProxyID_AccountProxyOverrides(t *testing.T) {
|
||||||
|
cfg := ProviderConfig{ProxyID: 42}
|
||||||
|
require.Equal(t, int64(0), resolveProxyID(cfg, "http://account-proxy:8080"))
|
||||||
|
require.Equal(t, int64(42), resolveProxyID(cfg, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- isProxyError ---
|
||||||
|
|
||||||
|
func TestIsProxyError_Nil(t *testing.T) {
|
||||||
|
require.False(t, isProxyError(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyError_ConnectionRefused(t *testing.T) {
|
||||||
|
require.True(t, isProxyError(fmt.Errorf("dial tcp: connection refused")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyError_Timeout(t *testing.T) {
|
||||||
|
require.True(t, isProxyError(fmt.Errorf("i/o timeout while connecting to proxy")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyError_SOCKS(t *testing.T) {
|
||||||
|
require.True(t, isProxyError(fmt.Errorf("socks connect failed")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyError_TLSHandshake(t *testing.T) {
|
||||||
|
require.True(t, isProxyError(fmt.Errorf("tls handshake timeout")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyError_APIError_NotProxy(t *testing.T) {
|
||||||
|
require.False(t, isProxyError(fmt.Errorf("API rate limit exceeded")))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- isProxyAvailable (nil Redis) ---
|
||||||
|
|
||||||
|
func TestIsProxyAvailable_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
require.True(t, m.isProxyAvailable(context.Background(), 42))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsProxyAvailable_ZeroID(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
require.True(t, m.isProxyAvailable(context.Background(), 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- selectByQuotaWeight ---
|
||||||
|
|
||||||
|
func TestSelectByQuotaWeight_NoQuotaLast(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
candidates := []ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
|
||||||
|
{Type: "tavily", APIKey: "k2", QuotaLimit: 100},
|
||||||
|
}
|
||||||
|
result := m.selectByQuotaWeight(context.Background(), candidates)
|
||||||
|
require.Len(t, result, 2)
|
||||||
|
require.Equal(t, "tavily", result[0].Type)
|
||||||
|
require.Equal(t, "brave", result[1].Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectByQuotaWeight_AllNoQuota(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
candidates := []ProviderConfig{
|
||||||
|
{Type: "brave", APIKey: "k1", QuotaLimit: 0},
|
||||||
|
{Type: "tavily", APIKey: "k2", QuotaLimit: 0},
|
||||||
|
}
|
||||||
|
result := m.selectByQuotaWeight(context.Background(), candidates)
|
||||||
|
require.Len(t, result, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectByQuotaWeight_Empty(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
result := m.selectByQuotaWeight(context.Background(), nil)
|
||||||
|
require.Empty(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- newHTTPClient ---
|
||||||
|
|
||||||
|
func TestNewHTTPClient_NoProxy(t *testing.T) {
|
||||||
|
c, err := newHTTPClient("")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
|
||||||
|
_, err := newHTTPClient("://bad-url")
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "invalid proxy URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPClient_ValidHTTPProxy(t *testing.T) {
|
||||||
|
c, err := newHTTPClient("http://proxy.example.com:8080")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewHTTPClient_ValidSOCKS5Proxy(t *testing.T) {
|
||||||
|
c, err := newHTTPClient("socks5://proxy.example.com:1080")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ResetUsage ---
|
||||||
|
|
||||||
|
func TestManager_ResetUsage_NilRedis(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil)
|
||||||
|
err := m.ResetUsage(context.Background(), "brave")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
11
backend/internal/pkg/websearch/provider.go
Normal file
11
backend/internal/pkg/websearch/provider.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// Provider is the interface every search backend must implement.
|
||||||
|
type Provider interface {
|
||||||
|
// Name returns the provider identifier ("brave" or "tavily").
|
||||||
|
Name() string
|
||||||
|
// Search executes a web search and returns results.
|
||||||
|
Search(ctx context.Context, req SearchRequest) (*SearchResponse, error)
|
||||||
|
}
|
||||||
107
backend/internal/pkg/websearch/tavily.go
Normal file
107
backend/internal/pkg/websearch/tavily.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tavilySearchEndpoint = "https://api.tavily.com/search"
|
||||||
|
tavilyProviderName = "tavily"
|
||||||
|
tavilySearchDepthBasic = "basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TavilyProvider implements web search via the Tavily Search API.
|
||||||
|
type TavilyProvider struct {
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTavilyProvider creates a Tavily Search provider.
|
||||||
|
// The caller is responsible for configuring the http.Client with proxy/timeouts.
|
||||||
|
func NewTavilyProvider(apiKey string, httpClient *http.Client) *TavilyProvider {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
return &TavilyProvider{apiKey: apiKey, httpClient: httpClient}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TavilyProvider) Name() string { return tavilyProviderName }
|
||||||
|
|
||||||
|
func (t *TavilyProvider) Search(ctx context.Context, req SearchRequest) (*SearchResponse, error) {
|
||||||
|
maxResults := req.MaxResults
|
||||||
|
if maxResults <= 0 {
|
||||||
|
maxResults = defaultMaxResults
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := tavilyRequest{
|
||||||
|
APIKey: t.apiKey,
|
||||||
|
Query: req.Query,
|
||||||
|
MaxResults: maxResults,
|
||||||
|
SearchDepth: tavilySearchDepthBasic,
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: encode request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tavilySearchEndpoint, bytes.NewReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: build request: %w", err)
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := t.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: read body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("tavily: status %d: %s", resp.StatusCode, truncateBody(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw tavilyResponse
|
||||||
|
if err := json.Unmarshal(body, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("tavily: decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
results := make([]SearchResult, 0, len(raw.Results))
|
||||||
|
for _, r := range raw.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL,
|
||||||
|
Title: r.Title,
|
||||||
|
Snippet: r.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SearchResponse{Results: results, Query: req.Query}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyRequest struct {
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
Query string `json:"query"`
|
||||||
|
MaxResults int `json:"max_results"`
|
||||||
|
SearchDepth string `json:"search_depth"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyResponse struct {
|
||||||
|
Results []tavilyResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tavilyResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Score float64 `json:"score"`
|
||||||
|
}
|
||||||
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
63
backend/internal/pkg/websearch/tavily_test.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTavilyProvider_Name(t *testing.T) {
|
||||||
|
p := NewTavilyProvider("key", nil)
|
||||||
|
require.Equal(t, "tavily", p.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_RequestConstruction(t *testing.T) {
|
||||||
|
// Verify tavilyRequest struct fields map correctly
|
||||||
|
req := tavilyRequest{
|
||||||
|
APIKey: "test-key",
|
||||||
|
Query: "golang",
|
||||||
|
MaxResults: 3,
|
||||||
|
SearchDepth: tavilySearchDepthBasic,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(data, &parsed))
|
||||||
|
require.Equal(t, "test-key", parsed["api_key"])
|
||||||
|
require.Equal(t, "golang", parsed["query"])
|
||||||
|
require.Equal(t, float64(3), parsed["max_results"])
|
||||||
|
require.Equal(t, "basic", parsed["search_depth"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_ResponseParsing(t *testing.T) {
|
||||||
|
rawResp := `{"results":[{"url":"https://go.dev","title":"Go","content":"Go programming language","score":0.95}]}`
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(rawResp), &resp))
|
||||||
|
require.Len(t, resp.Results, 1)
|
||||||
|
require.Equal(t, "https://go.dev", resp.Results[0].URL)
|
||||||
|
require.Equal(t, "Go programming language", resp.Results[0].Content)
|
||||||
|
require.InDelta(t, 0.95, resp.Results[0].Score, 0.001)
|
||||||
|
|
||||||
|
// Verify mapping to SearchResult
|
||||||
|
results := make([]SearchResult, 0, len(resp.Results))
|
||||||
|
for _, r := range resp.Results {
|
||||||
|
results = append(results, SearchResult{
|
||||||
|
URL: r.URL, Title: r.Title, Snippet: r.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
require.Equal(t, "Go programming language", results[0].Snippet)
|
||||||
|
require.Equal(t, "", results[0].PageAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_EmptyResults(t *testing.T) {
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(`{"results":[]}`), &resp))
|
||||||
|
require.Empty(t, resp.Results)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTavilyProvider_Search_InvalidJSON(t *testing.T) {
|
||||||
|
var resp tavilyResponse
|
||||||
|
require.Error(t, json.Unmarshal([]byte("not json"), &resp))
|
||||||
|
}
|
||||||
30
backend/internal/pkg/websearch/types.go
Normal file
30
backend/internal/pkg/websearch/types.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package websearch
|
||||||
|
|
||||||
|
// SearchResult represents a single web search result.
|
||||||
|
type SearchResult struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Snippet string `json:"snippet"`
|
||||||
|
PageAge string `json:"page_age,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchRequest describes a web search to perform.
|
||||||
|
type SearchRequest struct {
|
||||||
|
Query string
|
||||||
|
MaxResults int // defaults to defaultMaxResults if <= 0
|
||||||
|
ProxyURL string // optional HTTP proxy URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResponse holds the results of a web search.
|
||||||
|
type SearchResponse struct {
|
||||||
|
Results []SearchResult
|
||||||
|
Query string // the query that was actually executed
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultMaxResults = 5
|
||||||
|
|
||||||
|
// Provider type identifiers.
|
||||||
|
const (
|
||||||
|
ProviderTypeBrave = "brave"
|
||||||
|
ProviderTypeTavily = "tavily"
|
||||||
|
)
|
||||||
@ -138,10 +138,17 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
|||||||
WithUser(func(q *dbent.UserQuery) {
|
WithUser(func(q *dbent.UserQuery) {
|
||||||
q.Select(
|
q.Select(
|
||||||
user.FieldID,
|
user.FieldID,
|
||||||
|
user.FieldEmail,
|
||||||
|
user.FieldUsername,
|
||||||
user.FieldStatus,
|
user.FieldStatus,
|
||||||
user.FieldRole,
|
user.FieldRole,
|
||||||
user.FieldBalance,
|
user.FieldBalance,
|
||||||
user.FieldConcurrency,
|
user.FieldConcurrency,
|
||||||
|
user.FieldBalanceNotifyEnabled,
|
||||||
|
user.FieldBalanceNotifyThresholdType,
|
||||||
|
user.FieldBalanceNotifyThreshold,
|
||||||
|
user.FieldBalanceNotifyExtraEmails,
|
||||||
|
user.FieldTotalRecharged,
|
||||||
)
|
)
|
||||||
}).
|
}).
|
||||||
WithGroup(func(q *dbent.GroupQuery) {
|
WithGroup(func(q *dbent.GroupQuery) {
|
||||||
@ -639,22 +646,31 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
if u == nil {
|
if u == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.User{
|
out := &service.User{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
PasswordHash: u.PasswordHash,
|
PasswordHash: u.PasswordHash,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||||
TotpEnabled: u.TotpEnabled,
|
TotpEnabled: u.TotpEnabled,
|
||||||
TotpEnabledAt: u.TotpEnabledAt,
|
TotpEnabledAt: u.TotpEnabledAt,
|
||||||
CreatedAt: u.CreatedAt,
|
BalanceNotifyEnabled: u.BalanceNotifyEnabled,
|
||||||
UpdatedAt: u.UpdatedAt,
|
BalanceNotifyThresholdType: u.BalanceNotifyThresholdType,
|
||||||
|
BalanceNotifyThreshold: u.BalanceNotifyThreshold,
|
||||||
|
TotalRecharged: u.TotalRecharged,
|
||||||
|
CreatedAt: u.CreatedAt,
|
||||||
|
UpdatedAt: u.UpdatedAt,
|
||||||
}
|
}
|
||||||
|
// Parse extra emails JSON (supports both old []string and new []NotifyEmailEntry format)
|
||||||
|
if u.BalanceNotifyExtraEmails != "" && u.BalanceNotifyExtraEmails != "[]" {
|
||||||
|
out.BalanceNotifyExtraEmails = service.ParseNotifyEmails(u.BalanceNotifyExtraEmails)
|
||||||
|
}
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||||
|
|||||||
@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
err = tx.QueryRowContext(ctx,
|
err = tx.QueryRowContext(ctx,
|
||||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models) VALUES ($1, $2, $3, $4, $5, $6)
|
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
RETURNING id, created_at, updated_at`,
|
RETURNING id, created_at, updated_at`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
|
||||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -67,17 +71,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 设置账号统计定价规则
|
||||||
|
if len(channel.AccountStatsPricingRules) > 0 {
|
||||||
|
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||||
ch := &service.Channel{}
|
ch := &service.Channel{}
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
err := r.db.QueryRowContext(ctx,
|
err := r.db.QueryRowContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
|
||||||
FROM channels WHERE id = $1`, id,
|
FROM channels WHERE id = $1`, id,
|
||||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt)
|
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, service.ErrChannelNotFound
|
return nil, service.ErrChannelNotFound
|
||||||
}
|
}
|
||||||
@ -85,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
|||||||
return nil, fmt.Errorf("get channel: %w", err)
|
return nil, fmt.Errorf("get channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
|
|
||||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -98,6 +110,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
|||||||
}
|
}
|
||||||
ch.ModelPricing = pricing
|
ch.ModelPricing = pricing
|
||||||
|
|
||||||
|
statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ch.AccountStatsPricingRules = statsPricingRules
|
||||||
|
|
||||||
return ch, nil
|
return ch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -107,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
result, err := tx.ExecContext(ctx,
|
result, err := tx.ExecContext(ctx,
|
||||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, updated_at = NOW()
|
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
|
||||||
WHERE id = $7`,
|
WHERE id = $10`,
|
||||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.ID,
|
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isUniqueViolation(err) {
|
if isUniqueViolation(err) {
|
||||||
@ -137,6 +159,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新账号统计定价规则
|
||||||
|
if channel.AccountStatsPricingRules != nil {
|
||||||
|
if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -187,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
|
|
||||||
// 查询 channel 列表
|
// 查询 channel 列表
|
||||||
dataQuery := fmt.Sprintf(
|
dataQuery := fmt.Sprintf(
|
||||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.created_at, c.updated_at
|
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||||
)
|
)
|
||||||
@ -203,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -225,9 +255,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||||
|
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
|
|||||||
|
|
||||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||||
rows, err := r.db.QueryContext(ctx,
|
rows, err := r.db.QueryContext(ctx,
|
||||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, created_at, updated_at FROM channels ORDER BY id`,
|
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query all channels: %w", err)
|
return nil, fmt.Errorf("query all channels: %w", err)
|
||||||
@ -284,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
|||||||
var channelIDs []int64
|
var channelIDs []int64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var ch service.Channel
|
var ch service.Channel
|
||||||
var modelMappingJSON []byte
|
var modelMappingJSON, featuresConfigJSON []byte
|
||||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||||
return nil, fmt.Errorf("scan channel: %w", err)
|
return nil, fmt.Errorf("scan channel: %w", err)
|
||||||
}
|
}
|
||||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||||
|
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||||
channels = append(channels, ch)
|
channels = append(channels, ch)
|
||||||
channelIDs = append(channelIDs, ch.ID)
|
channelIDs = append(channelIDs, ch.ID)
|
||||||
}
|
}
|
||||||
@ -312,9 +348,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 批量加载账号统计定价规则
|
||||||
|
statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||||
|
channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID]
|
||||||
}
|
}
|
||||||
|
|
||||||
return channels, nil
|
return channels, nil
|
||||||
@ -456,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||||
|
if len(m) == 0 {
|
||||||
|
return []byte("{}"), nil
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var m map[string]any
|
||||||
|
if err := json.Unmarshal(data, &m); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||||
if len(groupIDs) == 0 {
|
if len(groupIDs) == 0 {
|
||||||
|
|||||||
@ -0,0 +1,244 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- 账号统计定价规则 ---
|
||||||
|
|
||||||
|
// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价)
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) {
|
||||||
|
// 1. 查询规则
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at
|
||||||
|
FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`,
|
||||||
|
pq.Array(channelIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var allRules []service.AccountStatsPricingRule
|
||||||
|
var ruleIDs []int64
|
||||||
|
for rows.Next() {
|
||||||
|
var rule service.AccountStatsPricingRule
|
||||||
|
if err := rows.Scan(
|
||||||
|
&rule.ID, &rule.ChannelID, &rule.Name,
|
||||||
|
pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs),
|
||||||
|
&rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
ruleIDs = append(ruleIDs, rule.ID)
|
||||||
|
allRules = append(allRules, rule)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 批量加载规则的模型定价
|
||||||
|
pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 按 channelID 分组并关联定价
|
||||||
|
result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs))
|
||||||
|
for i := range allRules {
|
||||||
|
allRules[i].Pricing = pricingMap[allRules[i].ID]
|
||||||
|
result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchLoadAccountStatsModelPricing 批量加载规则的模型定价
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||||
|
if len(ruleIDs) == 0 {
|
||||||
|
return make(map[int64][]service.ChannelModelPricing), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, rule_id, platform, models, billing_mode, input_price, output_price,
|
||||||
|
cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at
|
||||||
|
FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`,
|
||||||
|
pq.Array(ruleIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs))
|
||||||
|
for rows.Next() {
|
||||||
|
var p service.ChannelModelPricing
|
||||||
|
var ruleID int64
|
||||||
|
var modelsJSON []byte
|
||||||
|
if err := rows.Scan(
|
||||||
|
&p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode,
|
||||||
|
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||||
|
&p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||||
|
p.Models = []string{}
|
||||||
|
}
|
||||||
|
pricingMap[ruleID] = append(pricingMap[ruleID], p)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load intervals for all pricing entries.
|
||||||
|
var allPricingIDs []int64
|
||||||
|
for _, pricings := range pricingMap {
|
||||||
|
for _, p := range pricings {
|
||||||
|
allPricingIDs = append(allPricingIDs, p.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(allPricingIDs) > 0 {
|
||||||
|
intervalsMap, err := r.batchLoadAccountStatsIntervals(ctx, allPricingIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for ruleID, pricings := range pricingMap {
|
||||||
|
for i := range pricings {
|
||||||
|
pricings[i].Intervals = intervalsMap[pricings[i].ID]
|
||||||
|
}
|
||||||
|
pricingMap[ruleID] = pricings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pricingMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用)
|
||||||
|
func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) {
|
||||||
|
result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result[channelID], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的)
|
||||||
|
func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error {
|
||||||
|
// CASCADE 会自动删除关联的 model_pricing
|
||||||
|
if _, err := tx.ExecContext(ctx,
|
||||||
|
`DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("delete old account stats pricing rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range rules {
|
||||||
|
rules[i].ChannelID = channelID
|
||||||
|
if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil {
|
||||||
|
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价
|
||||||
|
func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error {
|
||||||
|
err := tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order)
|
||||||
|
VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`,
|
||||||
|
rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder,
|
||||||
|
).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert account stats pricing rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for j := range rule.Pricing {
|
||||||
|
if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价
|
||||||
|
func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error {
|
||||||
|
modelsJSON, err := json.Marshal(pricing.Models)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal models: %w", err)
|
||||||
|
}
|
||||||
|
billingMode := pricing.BillingMode
|
||||||
|
if billingMode == "" {
|
||||||
|
billingMode = service.BillingModeToken
|
||||||
|
}
|
||||||
|
platform := pricing.Platform
|
||||||
|
err = tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||||
|
ruleID, platform, modelsJSON, billingMode,
|
||||||
|
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||||
|
pricing.ImageOutputPrice, pricing.PerRequestPrice,
|
||||||
|
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert account stats model pricing: %w", err)
|
||||||
|
}
|
||||||
|
// Persist intervals (mirrors channel_pricing_intervals logic).
|
||||||
|
for i := range pricing.Intervals {
|
||||||
|
iv := &pricing.Intervals[i]
|
||||||
|
iv.PricingID = pricing.ID
|
||||||
|
if err := createAccountStatsIntervalTx(ctx, tx, iv); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAccountStatsIntervalTx inserts a single interval for an account stats pricing entry.
|
||||||
|
func createAccountStatsIntervalTx(ctx context.Context, tx *sql.Tx, iv *service.PricingInterval) error {
|
||||||
|
return tx.QueryRowContext(ctx,
|
||||||
|
`INSERT INTO channel_account_stats_pricing_intervals
|
||||||
|
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||||
|
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
||||||
|
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
||||||
|
iv.PerRequestPrice, iv.SortOrder,
|
||||||
|
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// batchLoadAccountStatsIntervals loads intervals for account stats pricing entries.
|
||||||
|
func (r *channelRepository) batchLoadAccountStatsIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
||||||
|
if len(pricingIDs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
rows, err := r.db.QueryContext(ctx,
|
||||||
|
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
||||||
|
input_price, output_price, cache_write_price, cache_read_price,
|
||||||
|
per_request_price, sort_order, created_at, updated_at
|
||||||
|
FROM channel_account_stats_pricing_intervals
|
||||||
|
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
||||||
|
pq.Array(pricingIDs),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch load account stats pricing intervals: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
result := make(map[int64][]service.PricingInterval)
|
||||||
|
for rows.Next() {
|
||||||
|
var iv service.PricingInterval
|
||||||
|
if err := rows.Scan(
|
||||||
|
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
||||||
|
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
||||||
|
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan account stats pricing interval: %w", err)
|
||||||
|
}
|
||||||
|
result[iv.PricingID] = append(result[iv.PricingID], iv)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
@ -331,6 +331,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
|
|||||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) AS account_cost,
|
||||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
@ -351,6 +352,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
|
|||||||
cache_read_tokens,
|
cache_read_tokens,
|
||||||
total_cost,
|
total_cost,
|
||||||
actual_cost,
|
actual_cost,
|
||||||
|
account_cost,
|
||||||
total_duration_ms,
|
total_duration_ms,
|
||||||
active_users,
|
active_users,
|
||||||
computed_at
|
computed_at
|
||||||
@ -364,6 +366,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
|
|||||||
hourly.cache_read_tokens,
|
hourly.cache_read_tokens,
|
||||||
hourly.total_cost,
|
hourly.total_cost,
|
||||||
hourly.actual_cost,
|
hourly.actual_cost,
|
||||||
|
hourly.account_cost,
|
||||||
hourly.total_duration_ms,
|
hourly.total_duration_ms,
|
||||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
NOW()
|
NOW()
|
||||||
@ -378,6 +381,7 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
|
|||||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
total_cost = EXCLUDED.total_cost,
|
total_cost = EXCLUDED.total_cost,
|
||||||
actual_cost = EXCLUDED.actual_cost,
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
account_cost = EXCLUDED.account_cost,
|
||||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
active_users = EXCLUDED.active_users,
|
active_users = EXCLUDED.active_users,
|
||||||
computed_at = EXCLUDED.computed_at
|
computed_at = EXCLUDED.computed_at
|
||||||
@ -399,6 +403,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||||
|
COALESCE(SUM(account_cost), 0) AS account_cost,
|
||||||
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||||
FROM usage_dashboard_hourly
|
FROM usage_dashboard_hourly
|
||||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||||
@ -419,6 +424,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
cache_read_tokens,
|
cache_read_tokens,
|
||||||
total_cost,
|
total_cost,
|
||||||
actual_cost,
|
actual_cost,
|
||||||
|
account_cost,
|
||||||
total_duration_ms,
|
total_duration_ms,
|
||||||
active_users,
|
active_users,
|
||||||
computed_at
|
computed_at
|
||||||
@ -432,6 +438,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
daily.cache_read_tokens,
|
daily.cache_read_tokens,
|
||||||
daily.total_cost,
|
daily.total_cost,
|
||||||
daily.actual_cost,
|
daily.actual_cost,
|
||||||
|
daily.account_cost,
|
||||||
daily.total_duration_ms,
|
daily.total_duration_ms,
|
||||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||||
NOW()
|
NOW()
|
||||||
@ -446,6 +453,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
|
|||||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||||
total_cost = EXCLUDED.total_cost,
|
total_cost = EXCLUDED.total_cost,
|
||||||
actual_cost = EXCLUDED.actual_cost,
|
actual_cost = EXCLUDED.actual_cost,
|
||||||
|
account_cost = EXCLUDED.account_cost,
|
||||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||||
active_users = EXCLUDED.active_users,
|
active_users = EXCLUDED.active_users,
|
||||||
computed_at = EXCLUDED.computed_at
|
computed_at = EXCLUDED.computed_at
|
||||||
|
|||||||
@ -3,6 +3,8 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@ -11,23 +13,33 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
verifyCodeKeyPrefix = "verify_code:"
|
verifyCodeKeyPrefix = "verify_code:"
|
||||||
|
notifyVerifyKeyPrefix = "notify_verify:"
|
||||||
passwordResetKeyPrefix = "password_reset:"
|
passwordResetKeyPrefix = "password_reset:"
|
||||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||||
|
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// verifyCodeKey generates the Redis key for email verification code.
|
// verifyCodeKey generates the Redis key for email verification code.
|
||||||
|
// Email is lowercased for case-insensitive consistency.
|
||||||
func verifyCodeKey(email string) string {
|
func verifyCodeKey(email string) string {
|
||||||
return verifyCodeKeyPrefix + email
|
return verifyCodeKeyPrefix + strings.ToLower(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// notifyVerifyKey generates the Redis key for notify email verification code.
|
||||||
|
// Email is lowercased to prevent case-sensitive key mismatch (the business layer
|
||||||
|
// uses strings.EqualFold for comparison).
|
||||||
|
func notifyVerifyKey(email string) string {
|
||||||
|
return notifyVerifyKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// passwordResetKey generates the Redis key for password reset token.
|
// passwordResetKey generates the Redis key for password reset token.
|
||||||
func passwordResetKey(email string) string {
|
func passwordResetKey(email string) string {
|
||||||
return passwordResetKeyPrefix + email
|
return passwordResetKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||||
func passwordResetSentAtKey(email string) string {
|
func passwordResetSentAtKey(email string) string {
|
||||||
return passwordResetSentAtKeyPrefix + email
|
return passwordResetSentAtKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
type emailCache struct {
|
type emailCache struct {
|
||||||
@ -106,3 +118,60 @@ func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email st
|
|||||||
key := passwordResetSentAtKey(email)
|
key := passwordResetSentAtKey(email)
|
||||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Notify email verification code methods
|
||||||
|
|
||||||
|
func (c *emailCache) GetNotifyVerifyCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||||
|
key := notifyVerifyKey(email)
|
||||||
|
val, err := c.rdb.Get(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var data service.VerificationCodeData
|
||||||
|
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) SetNotifyVerifyCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||||
|
key := notifyVerifyKey(email)
|
||||||
|
val, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
|
||||||
|
key := notifyVerifyKey(email)
|
||||||
|
return c.rdb.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// User-level rate limiting for notify email verification codes
|
||||||
|
|
||||||
|
func notifyCodeUserRateKey(userID int64) string {
|
||||||
|
return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||||
|
key := notifyCodeUserRateKey(userID)
|
||||||
|
count, err := c.rdb.Incr(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
|
||||||
|
if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
|
||||||
|
return count, fmt.Errorf("expire notify code rate key: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
key := notifyCodeUserRateKey(userID)
|
||||||
|
count, err := c.rdb.Get(ctx, key).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|||||||
@ -426,6 +426,13 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
|
|||||||
"window_cost_sticky_reserve",
|
"window_cost_sticky_reserve",
|
||||||
"max_sessions",
|
"max_sessions",
|
||||||
"session_idle_timeout_minutes",
|
"session_idle_timeout_minutes",
|
||||||
|
"openai_oauth_responses_websockets_v2_enabled",
|
||||||
|
"openai_oauth_responses_websockets_v2_mode",
|
||||||
|
"openai_apikey_responses_websockets_v2_enabled",
|
||||||
|
"openai_apikey_responses_websockets_v2_mode",
|
||||||
|
"responses_websockets_v2_enabled",
|
||||||
|
"openai_ws_enabled",
|
||||||
|
"openai_ws_force_http",
|
||||||
}
|
}
|
||||||
filtered := make(map[string]any)
|
filtered := make(map[string]any)
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
|
|||||||
33
backend/internal/repository/scheduler_cache_unit_test.go
Normal file
33
backend/internal/repository/scheduler_cache_unit_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
|
||||||
|
account := service.Account{
|
||||||
|
ID: 42,
|
||||||
|
Platform: service.PlatformOpenAI,
|
||||||
|
Type: service.AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||||
|
"openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
|
||||||
|
"openai_ws_force_http": true,
|
||||||
|
"mixed_scheduling": true,
|
||||||
|
"unused_large_field": "drop-me",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got := buildSchedulerMetadataAccount(account)
|
||||||
|
|
||||||
|
require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
|
||||||
|
require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
|
||||||
|
require.Equal(t, true, got.Extra["openai_ws_force_http"])
|
||||||
|
require.Equal(t, true, got.Extra["mixed_scheduling"])
|
||||||
|
require.Nil(t, got.Extra["unused_large_field"])
|
||||||
|
}
|
||||||
@ -113,9 +113,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cmd.BalanceCost > 0 {
|
if cmd.BalanceCost > 0 {
|
||||||
if err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost); err != nil {
|
newBalance, err := deductUsageBillingBalance(ctx, tx, cmd.UserID, cmd.BalanceCost)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
result.NewBalance = &newBalance
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.APIKeyQuotaCost > 0 {
|
if cmd.APIKeyQuotaCost > 0 {
|
||||||
@ -133,9 +135,11 @@ func (r *usageBillingRepository) applyUsageBillingEffects(ctx context.Context, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
if cmd.AccountQuotaCost > 0 && (strings.EqualFold(cmd.AccountType, service.AccountTypeAPIKey) || strings.EqualFold(cmd.AccountType, service.AccountTypeBedrock)) {
|
||||||
if err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost); err != nil {
|
quotaState, err := incrementUsageBillingAccountQuota(ctx, tx, cmd.AccountID, cmd.AccountQuotaCost)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
result.QuotaState = quotaState
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -169,24 +173,22 @@ func incrementUsageBillingSubscription(ctx context.Context, tx *sql.Tx, subscrip
|
|||||||
return service.ErrSubscriptionNotFound
|
return service.ErrSubscriptionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) error {
|
func deductUsageBillingBalance(ctx context.Context, tx *sql.Tx, userID int64, amount float64) (float64, error) {
|
||||||
res, err := tx.ExecContext(ctx, `
|
var newBalance float64
|
||||||
|
err := tx.QueryRowContext(ctx, `
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET balance = balance - $1,
|
SET balance = balance - $1,
|
||||||
updated_at = NOW()
|
updated_at = NOW()
|
||||||
WHERE id = $2 AND deleted_at IS NULL
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
`, amount, userID)
|
RETURNING balance
|
||||||
|
`, amount, userID).Scan(&newBalance)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return 0, service.ErrUserNotFound
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return 0, err
|
||||||
}
|
}
|
||||||
affected, err := res.RowsAffected()
|
return newBalance, nil
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected > 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return service.ErrUserNotFound
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
func incrementUsageBillingAPIKeyQuota(ctx context.Context, tx *sql.Tx, apiKeyID int64, amount float64) (bool, error) {
|
||||||
@ -240,7 +242,7 @@ func incrementUsageBillingAPIKeyRateLimit(ctx context.Context, tx *sql.Tx, apiKe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) error {
|
func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountID int64, amount float64) (*service.AccountQuotaState, error) {
|
||||||
rows, err := tx.QueryContext(ctx,
|
rows, err := tx.QueryContext(ctx,
|
||||||
`UPDATE accounts SET extra = (
|
`UPDATE accounts SET extra = (
|
||||||
COALESCE(extra, '{}'::jsonb)
|
COALESCE(extra, '{}'::jsonb)
|
||||||
@ -248,61 +250,71 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI
|
|||||||
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_daily_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_daily_used',
|
'quota_daily_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_daily_used')::numeric, 0) + $1 END,
|
||||||
'quota_daily_start',
|
'quota_daily_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_daily_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+dailyExpiredExpr+`
|
||||||
+ '24 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_daily_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
|| CASE WHEN `+dailyExpiredExpr+` AND `+nextDailyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_daily_reset_at', `+nextDailyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
|| CASE WHEN COALESCE((extra->>'quota_weekly_limit')::numeric, 0) > 0 THEN
|
||||||
jsonb_build_object(
|
jsonb_build_object(
|
||||||
'quota_weekly_used',
|
'quota_weekly_used',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN $1
|
THEN $1
|
||||||
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
ELSE COALESCE((extra->>'quota_weekly_used')::numeric, 0) + $1 END,
|
||||||
'quota_weekly_start',
|
'quota_weekly_start',
|
||||||
CASE WHEN COALESCE((extra->>'quota_weekly_start')::timestamptz, '1970-01-01'::timestamptz)
|
CASE WHEN `+weeklyExpiredExpr+`
|
||||||
+ '168 hours'::interval <= NOW()
|
|
||||||
THEN `+nowUTC+`
|
THEN `+nowUTC+`
|
||||||
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
ELSE COALESCE(extra->>'quota_weekly_start', `+nowUTC+`) END
|
||||||
)
|
)
|
||||||
|
|| CASE WHEN `+weeklyExpiredExpr+` AND `+nextWeeklyResetAtExpr+` IS NOT NULL
|
||||||
|
THEN jsonb_build_object('quota_weekly_reset_at', `+nextWeeklyResetAtExpr+`)
|
||||||
|
ELSE '{}'::jsonb END
|
||||||
ELSE '{}'::jsonb END
|
ELSE '{}'::jsonb END
|
||||||
), updated_at = NOW()
|
), updated_at = NOW()
|
||||||
WHERE id = $2 AND deleted_at IS NULL
|
WHERE id = $2 AND deleted_at IS NULL
|
||||||
RETURNING
|
RETURNING
|
||||||
COALESCE((extra->>'quota_used')::numeric, 0),
|
COALESCE((extra->>'quota_used')::numeric, 0),
|
||||||
COALESCE((extra->>'quota_limit')::numeric, 0)`,
|
COALESCE((extra->>'quota_limit')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_daily_used')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_daily_limit')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_weekly_used')::numeric, 0),
|
||||||
|
COALESCE((extra->>'quota_weekly_limit')::numeric, 0)`,
|
||||||
amount, accountID)
|
amount, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer func() { _ = rows.Close() }()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
var newUsed, limit float64
|
var state service.AccountQuotaState
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
if err := rows.Scan(&newUsed, &limit); err != nil {
|
if err := rows.Scan(
|
||||||
return err
|
&state.TotalUsed, &state.TotalLimit,
|
||||||
|
&state.DailyUsed, &state.DailyLimit,
|
||||||
|
&state.WeeklyUsed, &state.WeeklyLimit,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
return service.ErrAccountNotFound
|
return nil, service.ErrAccountNotFound
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
if err := rows.Err(); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if limit > 0 && newUsed >= limit && (newUsed-amount) < limit {
|
if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit {
|
||||||
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil {
|
||||||
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return &state, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,7 +28,7 @@ import (
|
|||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at"
|
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
|
||||||
|
|
||||||
// usageLogInsertArgTypes must stay in the same order as:
|
// usageLogInsertArgTypes must stay in the same order as:
|
||||||
// 1. prepareUsageLogInsert().args
|
// 1. prepareUsageLogInsert().args
|
||||||
@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{
|
|||||||
"text", // model_mapping_chain
|
"text", // model_mapping_chain
|
||||||
"text", // billing_tier
|
"text", // billing_tier
|
||||||
"text", // billing_mode
|
"text", // billing_mode
|
||||||
|
"numeric", // account_stats_cost
|
||||||
"timestamptz", // created_at
|
"timestamptz", // created_at
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7,
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
RETURNING id, created_at
|
RETURNING id, created_at
|
||||||
@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) AS (VALUES `)
|
) AS (VALUES `)
|
||||||
|
|
||||||
args := make([]any, 0, len(preparedList)*45)
|
args := make([]any, 0, len(preparedList)*46)
|
||||||
argPos := 1
|
argPos := 1
|
||||||
for idx, prepared := range preparedList {
|
for idx, prepared := range preparedList {
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
FROM input
|
FROM input
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
model_mapping_chain,
|
model_mapping_chain,
|
||||||
billing_tier,
|
billing_tier,
|
||||||
billing_mode,
|
billing_mode,
|
||||||
|
account_stats_cost,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1, $2, $3, $4, $5, $6, $7,
|
$1, $2, $3, $4, $5, $6, $7,
|
||||||
@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
|
|||||||
$10, $11, $12, $13,
|
$10, $11, $12, $13,
|
||||||
$14, $15, $16, $17,
|
$14, $15, $16, $17,
|
||||||
$18, $19, $20, $21, $22, $23,
|
$18, $19, $20, $21, $22, $23,
|
||||||
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45
|
$24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
|
||||||
)
|
)
|
||||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||||
`, prepared.args...)
|
`, prepared.args...)
|
||||||
@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
|
|||||||
modelMappingChain,
|
modelMappingChain,
|
||||||
billingTier,
|
billingTier,
|
||||||
billingMode,
|
billingMode,
|
||||||
|
log.AccountStatsCost, // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -1518,6 +1528,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
|||||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
|
COALESCE(SUM(account_cost), 0) as total_account_cost,
|
||||||
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||||||
FROM usage_dashboard_daily
|
FROM usage_dashboard_daily
|
||||||
`
|
`
|
||||||
@ -1534,6 +1545,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
|||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalCacheReadTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
|
&stats.TotalAccountCost,
|
||||||
&totalDurationMs,
|
&totalDurationMs,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1552,6 +1564,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
|||||||
cache_read_tokens as today_cache_read_tokens,
|
cache_read_tokens as today_cache_read_tokens,
|
||||||
total_cost as today_cost,
|
total_cost as today_cost,
|
||||||
actual_cost as today_actual_cost,
|
actual_cost as today_actual_cost,
|
||||||
|
account_cost as today_account_cost,
|
||||||
active_users as active_users
|
active_users as active_users
|
||||||
FROM usage_dashboard_daily
|
FROM usage_dashboard_daily
|
||||||
WHERE bucket_date = $1::date
|
WHERE bucket_date = $1::date
|
||||||
@ -1568,6 +1581,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
|
|||||||
&stats.TodayCacheReadTokens,
|
&stats.TodayCacheReadTokens,
|
||||||
&stats.TodayCost,
|
&stats.TodayCost,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayActualCost,
|
||||||
|
&stats.TodayAccountCost,
|
||||||
&stats.ActiveUsers,
|
&stats.ActiveUsers,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
@ -1603,6 +1617,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
|||||||
cache_read_tokens,
|
cache_read_tokens,
|
||||||
total_cost,
|
total_cost,
|
||||||
actual_cost,
|
actual_cost,
|
||||||
|
COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) AS account_cost,
|
||||||
COALESCE(duration_ms, 0) AS duration_ms
|
COALESCE(duration_ms, 0) AS duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
|
||||||
@ -1616,6 +1631,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
|||||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
|
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
|
||||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
|
||||||
|
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_account_cost,
|
||||||
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
|
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
|
||||||
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
|
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
|
||||||
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
|
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
|
||||||
@ -1623,7 +1639,8 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
|||||||
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
|
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
|
||||||
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
|
||||||
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
|
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
|
||||||
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
|
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost,
|
||||||
|
COALESCE(SUM(account_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_account_cost
|
||||||
FROM scoped
|
FROM scoped
|
||||||
`
|
`
|
||||||
var totalDurationMs int64
|
var totalDurationMs int64
|
||||||
@ -1639,6 +1656,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
|||||||
&stats.TotalCacheReadTokens,
|
&stats.TotalCacheReadTokens,
|
||||||
&stats.TotalCost,
|
&stats.TotalCost,
|
||||||
&stats.TotalActualCost,
|
&stats.TotalActualCost,
|
||||||
|
&stats.TotalAccountCost,
|
||||||
&totalDurationMs,
|
&totalDurationMs,
|
||||||
&stats.TodayRequests,
|
&stats.TodayRequests,
|
||||||
&stats.TodayInputTokens,
|
&stats.TodayInputTokens,
|
||||||
@ -1647,6 +1665,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
|
|||||||
&stats.TodayCacheReadTokens,
|
&stats.TodayCacheReadTokens,
|
||||||
&stats.TodayCost,
|
&stats.TodayCost,
|
||||||
&stats.TodayActualCost,
|
&stats.TodayActualCost,
|
||||||
|
&stats.TodayAccountCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1959,7 +1978,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -1989,7 +2008,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
|||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -2026,7 +2045,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
|
|||||||
account_id,
|
account_id,
|
||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
@ -2585,7 +2604,8 @@ func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64
|
|||||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
COALESCE(SUM(actual_cost), 0) as actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
WHERE user_id = $1 AND created_at >= $2 AND created_at < $3
|
||||||
GROUP BY model
|
GROUP BY model
|
||||||
@ -2990,8 +3010,9 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
|
|||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
accountCostExpr := "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost"
|
||||||
modelExpr := resolveModelDimensionExpression(source)
|
modelExpr := resolveModelDimensionExpression(source)
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
@ -3004,10 +3025,11 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex
|
|||||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
|
%s,
|
||||||
%s
|
%s
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE created_at >= $1 AND created_at < $2
|
WHERE created_at >= $1 AND created_at < $2
|
||||||
`, modelExpr, actualCostExpr)
|
`, modelExpr, actualCostExpr, accountCostExpr)
|
||||||
|
|
||||||
args := []any{startTime, endTime}
|
args := []any{startTime, endTime}
|
||||||
if userID > 0 {
|
if userID > 0 {
|
||||||
@ -3062,7 +3084,8 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||||
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
|
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
|
||||||
FROM usage_logs ul
|
FROM usage_logs ul
|
||||||
LEFT JOIN groups g ON g.id = ul.group_id
|
LEFT JOIN groups g ON g.id = ul.group_id
|
||||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||||
@ -3113,6 +3136,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start
|
|||||||
&row.TotalTokens,
|
&row.TotalTokens,
|
||||||
&row.Cost,
|
&row.Cost,
|
||||||
&row.ActualCost,
|
&row.ActualCost,
|
||||||
|
&row.AccountCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -3133,7 +3157,8 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
|
||||||
COALESCE(SUM(ul.total_cost), 0) as cost,
|
COALESCE(SUM(ul.total_cost), 0) as cost,
|
||||||
COALESCE(SUM(ul.actual_cost), 0) as actual_cost
|
COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
|
||||||
|
COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
|
||||||
FROM usage_logs ul
|
FROM usage_logs ul
|
||||||
LEFT JOIN users u ON u.id = ul.user_id
|
LEFT JOIN users u ON u.id = ul.user_id
|
||||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||||
@ -3204,6 +3229,7 @@ func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTim
|
|||||||
&row.TotalTokens,
|
&row.TotalTokens,
|
||||||
&row.Cost,
|
&row.Cost,
|
||||||
&row.ActualCost,
|
&row.ActualCost,
|
||||||
|
&row.AccountCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -3358,7 +3384,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
%s
|
%s
|
||||||
@ -3382,9 +3408,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
|||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if filters.AccountID > 0 {
|
stats.TotalAccountCost = &totalAccountCost
|
||||||
stats.TotalAccountCost = &totalAccountCost
|
|
||||||
}
|
|
||||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||||
|
|
||||||
start := time.Unix(0, 0).UTC()
|
start := time.Unix(0, 0).UTC()
|
||||||
@ -3433,7 +3457,7 @@ type EndpointStat = usagestats.EndpointStat
|
|||||||
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
@ -3500,7 +3524,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con
|
|||||||
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) {
|
||||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||||
}
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf(`
|
query := fmt.Sprintf(`
|
||||||
@ -3591,7 +3615,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
|||||||
COUNT(*) as requests,
|
COUNT(*) as requests,
|
||||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||||
COALESCE(SUM(total_cost), 0) as cost,
|
COALESCE(SUM(total_cost), 0) as cost,
|
||||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||||
FROM usage_logs
|
FROM usage_logs
|
||||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||||
@ -4069,6 +4093,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
modelMappingChain sql.NullString
|
modelMappingChain sql.NullString
|
||||||
billingTier sql.NullString
|
billingTier sql.NullString
|
||||||
billingMode sql.NullString
|
billingMode sql.NullString
|
||||||
|
accountStatsCost sql.NullFloat64
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4118,6 +4143,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
&modelMappingChain,
|
&modelMappingChain,
|
||||||
&billingTier,
|
&billingTier,
|
||||||
&billingMode,
|
&billingMode,
|
||||||
|
&accountStatsCost,
|
||||||
&createdAt,
|
&createdAt,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -4214,6 +4240,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
|||||||
if billingMode.Valid {
|
if billingMode.Valid {
|
||||||
log.BillingMode = &billingMode.String
|
log.BillingMode = &billingMode.String
|
||||||
}
|
}
|
||||||
|
if accountStatsCost.Valid {
|
||||||
|
log.AccountStatsCost = &accountStatsCost.Float64
|
||||||
|
}
|
||||||
|
|
||||||
return log, nil
|
return log, nil
|
||||||
}
|
}
|
||||||
@ -4257,6 +4286,7 @@ func scanModelStatsRows(rows *sql.Rows) ([]ModelStat, error) {
|
|||||||
&row.TotalTokens,
|
&row.TotalTokens,
|
||||||
&row.Cost,
|
&row.Cost,
|
||||||
&row.ActualCost,
|
&row.ActualCost,
|
||||||
|
&row.AccountCost,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -753,8 +753,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
|||||||
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
s.Require().Equal(baseStats.TotalTokens+int64(51), stats.TotalTokens, "TotalTokens mismatch")
|
||||||
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
|
s.Require().Equal(baseStats.TotalCost+2.3, stats.TotalCost, "TotalCost mismatch")
|
||||||
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
s.Require().Equal(baseStats.TotalActualCost+2.0, stats.TotalActualCost, "TotalActualCost mismatch")
|
||||||
|
// account_cost falls back to total_cost when account_stats_cost is NULL
|
||||||
|
s.Require().Equal(baseStats.TotalAccountCost+2.3, stats.TotalAccountCost, "TotalAccountCost mismatch")
|
||||||
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
s.Require().GreaterOrEqual(stats.TodayRequests, int64(1), "expected TodayRequests >= 1")
|
||||||
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
s.Require().GreaterOrEqual(stats.TodayCost, 0.0, "expected TodayCost >= 0")
|
||||||
|
s.Require().GreaterOrEqual(stats.TodayAccountCost, 0.0, "expected TodayAccountCost >= 0")
|
||||||
|
|
||||||
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
|
wantRpm, wantTpm, err := s.repo.getPerformanceStats(s.ctx, 0)
|
||||||
s.Require().NoError(err, "getPerformanceStats")
|
s.Require().NoError(err, "getPerformanceStats")
|
||||||
@ -833,6 +836,8 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
|
|||||||
s.Require().Equal(int64(45), stats.TotalTokens)
|
s.Require().Equal(int64(45), stats.TotalTokens)
|
||||||
s.Require().Equal(1.5, stats.TotalCost)
|
s.Require().Equal(1.5, stats.TotalCost)
|
||||||
s.Require().Equal(1.4, stats.TotalActualCost)
|
s.Require().Equal(1.4, stats.TotalActualCost)
|
||||||
|
// account_cost = COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1) = total_cost
|
||||||
|
s.Require().Equal(1.5, stats.TotalAccountCost)
|
||||||
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // model_mapping_chain
|
sqlmock.AnyArg(), // model_mapping_chain
|
||||||
sqlmock.AnyArg(), // billing_tier
|
sqlmock.AnyArg(), // billing_tier
|
||||||
sqlmock.AnyArg(), // billing_mode
|
sqlmock.AnyArg(), // billing_mode
|
||||||
|
sqlmock.AnyArg(), // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
|
||||||
@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
|
|||||||
sqlmock.AnyArg(), // model_mapping_chain
|
sqlmock.AnyArg(), // model_mapping_chain
|
||||||
sqlmock.AnyArg(), // billing_tier
|
sqlmock.AnyArg(), // billing_tier
|
||||||
sqlmock.AnyArg(), // billing_mode
|
sqlmock.AnyArg(), // billing_mode
|
||||||
|
sqlmock.AnyArg(), // account_stats_cost
|
||||||
createdAt,
|
createdAt,
|
||||||
).
|
).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt))
|
||||||
@ -299,7 +301,7 @@ func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testin
|
|||||||
|
|
||||||
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
|
||||||
WithArgs(start, end, requestType).
|
WithArgs(start, end, requestType).
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost"}))
|
WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "cache_creation_tokens", "cache_read_tokens", "total_tokens", "cost", "actual_cost", "account_cost"}))
|
||||||
|
|
||||||
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -344,6 +346,93 @@ func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T)
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, int64(1), stats.TotalRequests)
|
require.Equal(t, int64(1), stats.TotalRequests)
|
||||||
require.Equal(t, int64(9), stats.TotalTokens)
|
require.Equal(t, int64(9), stats.TotalTokens)
|
||||||
|
require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost should always be returned")
|
||||||
|
require.Equal(t, 1.2, *stats.TotalAccountCost)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryGetModelStatsAccountCostColumn(t *testing.T) {
|
||||||
|
db, mock := newSQLMock(t)
|
||||||
|
repo := &usageLogRepository{sql: db}
|
||||||
|
|
||||||
|
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
end := start.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
mock.ExpectQuery("FROM usage_logs").
|
||||||
|
WithArgs(start, end).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{
|
||||||
|
"model", "requests", "input_tokens", "output_tokens",
|
||||||
|
"cache_creation_tokens", "cache_read_tokens", "total_tokens",
|
||||||
|
"cost", "actual_cost", "account_cost",
|
||||||
|
}).
|
||||||
|
AddRow("claude-opus-4-6", int64(10), int64(100), int64(200), int64(5), int64(3), int64(308), 2.5, 2.0, 1.8).
|
||||||
|
AddRow("claude-sonnet-4-6", int64(5), int64(50), int64(100), int64(0), int64(0), int64(150), 1.0, 0.8, 0.7))
|
||||||
|
|
||||||
|
results, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
require.Equal(t, "claude-opus-4-6", results[0].Model)
|
||||||
|
require.Equal(t, 2.5, results[0].Cost)
|
||||||
|
require.Equal(t, 2.0, results[0].ActualCost)
|
||||||
|
require.Equal(t, 1.8, results[0].AccountCost)
|
||||||
|
require.Equal(t, "claude-sonnet-4-6", results[1].Model)
|
||||||
|
require.Equal(t, 0.7, results[1].AccountCost)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryGetGroupStatsAccountCostColumn(t *testing.T) {
|
||||||
|
db, mock := newSQLMock(t)
|
||||||
|
repo := &usageLogRepository{sql: db}
|
||||||
|
|
||||||
|
start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
end := start.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
mock.ExpectQuery("FROM usage_logs").
|
||||||
|
WithArgs(start, end).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{
|
||||||
|
"group_id", "group_name", "requests", "total_tokens",
|
||||||
|
"cost", "actual_cost", "account_cost",
|
||||||
|
}).
|
||||||
|
AddRow(int64(1), "azure-cc", int64(100), int64(5000), 10.0, 8.5, 7.2).
|
||||||
|
AddRow(int64(2), "max", int64(50), int64(2000), 5.0, 4.0, 3.5))
|
||||||
|
|
||||||
|
results, err := repo.GetGroupStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, nil, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
require.Equal(t, int64(1), results[0].GroupID)
|
||||||
|
require.Equal(t, "azure-cc", results[0].GroupName)
|
||||||
|
require.Equal(t, 10.0, results[0].Cost)
|
||||||
|
require.Equal(t, 8.5, results[0].ActualCost)
|
||||||
|
require.Equal(t, 7.2, results[0].AccountCost)
|
||||||
|
require.Equal(t, int64(2), results[1].GroupID)
|
||||||
|
require.Equal(t, 3.5, results[1].AccountCost)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageLogRepositoryGetStatsWithFiltersAlwaysReturnsAccountCost(t *testing.T) {
|
||||||
|
db, mock := newSQLMock(t)
|
||||||
|
repo := &usageLogRepository{sql: db}
|
||||||
|
|
||||||
|
// No AccountID filter set - TotalAccountCost should still be returned
|
||||||
|
filters := usagestats.UsageLogFilters{}
|
||||||
|
|
||||||
|
mock.ExpectQuery("FROM usage_logs").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{
|
||||||
|
"total_requests", "total_input_tokens", "total_output_tokens",
|
||||||
|
"total_cache_tokens", "total_cost", "total_actual_cost",
|
||||||
|
"total_account_cost", "avg_duration_ms",
|
||||||
|
}).AddRow(int64(50), int64(1000), int64(2000), int64(100), 15.0, 12.5, 11.0, 100.0))
|
||||||
|
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(inbound_endpoint\\)").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
|
||||||
|
mock.ExpectQuery("SELECT COALESCE\\(NULLIF\\(TRIM\\(upstream_endpoint\\)").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
|
||||||
|
mock.ExpectQuery("SELECT CONCAT\\(").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"endpoint", "requests", "total_tokens", "cost", "actual_cost"}))
|
||||||
|
|
||||||
|
stats, err := repo.GetStatsWithFilters(context.Background(), filters)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, stats.TotalAccountCost, "TotalAccountCost must always be returned, even without AccountID filter")
|
||||||
|
require.Equal(t, 11.0, *stats.TotalAccountCost)
|
||||||
require.NoError(t, mock.ExpectationsWereMet())
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -483,10 +572,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -530,10 +620,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -577,10 +668,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
|
|||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
sql.NullString{},
|
sql.NullString{},
|
||||||
false,
|
false,
|
||||||
sql.NullInt64{}, // channel_id
|
sql.NullInt64{}, // channel_id
|
||||||
sql.NullString{}, // model_mapping_chain
|
sql.NullString{}, // model_mapping_chain
|
||||||
sql.NullString{}, // billing_tier
|
sql.NullString{}, // billing_tier
|
||||||
sql.NullString{}, // billing_mode
|
sql.NullString{}, // billing_mode
|
||||||
|
sql.NullFloat64{}, // account_stats_cost
|
||||||
now,
|
now,
|
||||||
}})
|
}})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@ -100,7 +100,7 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6
|
|||||||
query := `
|
query := `
|
||||||
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier
|
||||||
FROM user_group_rate_multipliers ugr
|
FROM user_group_rate_multipliers ugr
|
||||||
JOIN users u ON u.id = ugr.user_id
|
JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL
|
||||||
WHERE ugr.group_id = $1
|
WHERE ugr.group_id = $1
|
||||||
ORDER BY ugr.user_id
|
ORDER BY ugr.user_id
|
||||||
`
|
`
|
||||||
|
|||||||
@ -137,7 +137,7 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
txClient = r.client
|
txClient = r.client
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := txClient.User.UpdateOneID(userIn.ID).
|
updateOp := txClient.User.UpdateOneID(userIn.ID).
|
||||||
SetEmail(userIn.Email).
|
SetEmail(userIn.Email).
|
||||||
SetUsername(userIn.Username).
|
SetUsername(userIn.Username).
|
||||||
SetNotes(userIn.Notes).
|
SetNotes(userIn.Notes).
|
||||||
@ -146,7 +146,15 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
|
|||||||
SetBalance(userIn.Balance).
|
SetBalance(userIn.Balance).
|
||||||
SetConcurrency(userIn.Concurrency).
|
SetConcurrency(userIn.Concurrency).
|
||||||
SetStatus(userIn.Status).
|
SetStatus(userIn.Status).
|
||||||
Save(ctx)
|
SetBalanceNotifyEnabled(userIn.BalanceNotifyEnabled).
|
||||||
|
SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType).
|
||||||
|
SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold).
|
||||||
|
SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)).
|
||||||
|
SetTotalRecharged(userIn.TotalRecharged)
|
||||||
|
if userIn.BalanceNotifyThreshold == nil {
|
||||||
|
updateOp = updateOp.ClearBalanceNotifyThreshold()
|
||||||
|
}
|
||||||
|
updated, err := updateOp.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
|
||||||
}
|
}
|
||||||
@ -382,7 +390,12 @@ func (r *userRepository) filterUsersByAttributes(ctx context.Context, attrs map[
|
|||||||
|
|
||||||
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
func (r *userRepository) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
n, err := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount).Save(ctx)
|
update := client.User.Update().Where(dbuser.IDEQ(id)).AddBalance(amount)
|
||||||
|
// Track cumulative recharge amount for percentage-based notifications
|
||||||
|
if amount > 0 {
|
||||||
|
update = update.AddTotalRecharged(amount)
|
||||||
|
}
|
||||||
|
n, err := update.Save(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
return translatePersistenceError(err, service.ErrUserNotFound, nil)
|
||||||
}
|
}
|
||||||
@ -549,6 +562,11 @@ func applyUserEntityToService(dst *service.User, src *dbent.User) {
|
|||||||
dst.UpdatedAt = src.UpdatedAt
|
dst.UpdatedAt = src.UpdatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// marshalExtraEmails serializes notify email entries to JSON for storage.
|
||||||
|
func marshalExtraEmails(entries []service.NotifyEmailEntry) string {
|
||||||
|
return service.MarshalNotifyEmails(entries)
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
// UpdateTotpSecret 更新用户的 TOTP 加密密钥
|
||||||
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
func (r *userRepository) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||||
client := clientFromContext(ctx, r.client)
|
client := clientFromContext(ctx, r.client)
|
||||||
|
|||||||
@ -58,6 +58,11 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"allowed_groups": null,
|
"allowed_groups": null,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
"updated_at": "2025-01-02T03:04:05Z",
|
"updated_at": "2025-01-02T03:04:05Z",
|
||||||
|
"balance_notify_enabled": false,
|
||||||
|
"balance_notify_threshold_type": "",
|
||||||
|
"balance_notify_threshold": null,
|
||||||
|
"balance_notify_extra_emails": null,
|
||||||
|
"total_recharged": 0,
|
||||||
"run_mode": "standard"
|
"run_mode": "standard"
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
@ -204,11 +209,10 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"image_price_1k": null,
|
"image_price_1k": null,
|
||||||
"image_price_2k": null,
|
"image_price_2k": null,
|
||||||
"image_price_4k": null,
|
"image_price_4k": null,
|
||||||
"claude_code_only": false,
|
"claude_code_only": false,
|
||||||
"allow_messages_dispatch": false,
|
"allow_messages_dispatch": false,
|
||||||
"fallback_group_id": null,
|
"fallback_group_id": null,
|
||||||
"fallback_group_id_on_invalid_request": null,
|
"fallback_group_id_on_invalid_request": null,
|
||||||
"allow_messages_dispatch": false,
|
|
||||||
"require_oauth_only": false,
|
"require_oauth_only": false,
|
||||||
"require_privacy_set": false,
|
"require_privacy_set": false,
|
||||||
"created_at": "2025-01-02T03:04:05Z",
|
"created_at": "2025-01-02T03:04:05Z",
|
||||||
@ -587,26 +591,34 @@ func TestAPIContracts(t *testing.T) {
|
|||||||
"enable_cch_signing": false,
|
"enable_cch_signing": false,
|
||||||
"enable_fingerprint_unification": true,
|
"enable_fingerprint_unification": true,
|
||||||
"enable_metadata_passthrough": false,
|
"enable_metadata_passthrough": false,
|
||||||
|
"web_search_emulation_enabled": false,
|
||||||
|
"custom_menu_items": [],
|
||||||
|
"custom_endpoints": [],
|
||||||
"payment_enabled": false,
|
"payment_enabled": false,
|
||||||
"payment_min_amount": 0,
|
"payment_min_amount": 0,
|
||||||
"payment_max_amount": 0,
|
"payment_max_amount": 0,
|
||||||
"payment_daily_limit": 0,
|
"payment_daily_limit": 0,
|
||||||
"payment_order_timeout_minutes": 0,
|
"payment_order_timeout_minutes": 0,
|
||||||
"payment_max_pending_orders": 0,
|
"payment_max_pending_orders": 0,
|
||||||
"payment_enabled_types": null,
|
|
||||||
"payment_balance_disabled": false,
|
"payment_balance_disabled": false,
|
||||||
|
"payment_balance_recharge_multiplier": 0,
|
||||||
|
"payment_recharge_fee_rate": 0,
|
||||||
"payment_load_balance_strategy": "",
|
"payment_load_balance_strategy": "",
|
||||||
"payment_product_name_prefix": "",
|
"payment_product_name_prefix": "",
|
||||||
"payment_product_name_suffix": "",
|
"payment_product_name_suffix": "",
|
||||||
"payment_help_image_url": "",
|
"payment_help_image_url": "",
|
||||||
"payment_help_text": "",
|
"payment_help_text": "",
|
||||||
|
"payment_enabled_types": null,
|
||||||
"payment_cancel_rate_limit_enabled": false,
|
"payment_cancel_rate_limit_enabled": false,
|
||||||
"payment_cancel_rate_limit_max": 0,
|
"payment_cancel_rate_limit_max": 0,
|
||||||
"payment_cancel_rate_limit_window": 0,
|
"payment_cancel_rate_limit_window": 0,
|
||||||
"payment_cancel_rate_limit_unit": "",
|
"payment_cancel_rate_limit_unit": "",
|
||||||
"payment_cancel_rate_limit_window_mode": "",
|
"payment_cancel_rate_limit_window_mode": "",
|
||||||
"custom_menu_items": [],
|
"balance_low_notify_enabled": false,
|
||||||
"custom_endpoints": []
|
"account_quota_notify_enabled": false,
|
||||||
|
"balance_low_notify_threshold": 0,
|
||||||
|
"balance_low_notify_recharge_url": "",
|
||||||
|
"account_quota_notify_emails": []
|
||||||
}
|
}
|
||||||
}`,
|
}`,
|
||||||
},
|
},
|
||||||
@ -699,7 +711,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
|||||||
RunMode: config.RunModeStandard,
|
RunMode: config.RunModeStandard,
|
||||||
}
|
}
|
||||||
|
|
||||||
userService := service.NewUserService(userRepo, nil, nil)
|
userService := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg)
|
||||||
|
|
||||||
usageRepo := newStubUsageLogRepo()
|
usageRepo := newStubUsageLogRepo()
|
||||||
|
|||||||
@ -2,12 +2,15 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/pkg/websearch"
|
||||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
@ -36,7 +39,6 @@ func ProvideRouter(
|
|||||||
opsService *service.OpsService,
|
opsService *service.OpsService,
|
||||||
settingService *service.SettingService,
|
settingService *service.SettingService,
|
||||||
redisClient *redis.Client,
|
redisClient *redis.Client,
|
||||||
langServerService *service.LanguageServerService,
|
|
||||||
) *gin.Engine {
|
) *gin.Engine {
|
||||||
if cfg.Server.Mode == "release" {
|
if cfg.Server.Mode == "release" {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@ -57,7 +59,43 @@ func ProvideRouter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient, langServerService)
|
// Wire up websearch Manager builder so it initializes on startup and rebuilds on config save.
|
||||||
|
settingService.SetWebSearchManagerBuilder(context.Background(), func(cfg *service.WebSearchEmulationConfig, proxyURLs map[int64]string) {
|
||||||
|
if cfg == nil || !cfg.Enabled || len(cfg.Providers) == 0 {
|
||||||
|
service.SetWebSearchManager(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
configs := make([]websearch.ProviderConfig, 0, len(cfg.Providers))
|
||||||
|
for _, p := range cfg.Providers {
|
||||||
|
if p.APIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pc := websearch.ProviderConfig{
|
||||||
|
Type: p.Type,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
QuotaLimit: derefInt64(p.QuotaLimit),
|
||||||
|
ExpiresAt: p.ExpiresAt,
|
||||||
|
}
|
||||||
|
if p.SubscribedAt != nil {
|
||||||
|
pc.SubscribedAt = p.SubscribedAt
|
||||||
|
}
|
||||||
|
if p.ProxyID != nil {
|
||||||
|
pc.ProxyID = *p.ProxyID
|
||||||
|
if u, ok := proxyURLs[*p.ProxyID]; ok {
|
||||||
|
pc.ProxyURL = u
|
||||||
|
} else {
|
||||||
|
// Proxy configured but not found — skip this provider to prevent direct connection.
|
||||||
|
slog.Warn("websearch: proxy not found for provider, skipping",
|
||||||
|
"provider", p.Type, "proxy_id", *p.ProxyID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
configs = append(configs, pc)
|
||||||
|
}
|
||||||
|
service.SetWebSearchManager(websearch.NewManager(configs, redisClient))
|
||||||
|
})
|
||||||
|
|
||||||
|
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProvideHTTPServer 提供 HTTP 服务器
|
// ProvideHTTPServer 提供 HTTP 服务器
|
||||||
@ -103,3 +141,10 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
|||||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func derefInt64(p *int64) int64 {
|
||||||
|
if p == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *p
|
||||||
|
}
|
||||||
|
|||||||
@ -39,7 +39,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
|
|||||||
return &clone, nil
|
return &clone, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
userService := service.NewUserService(userRepo, nil, nil)
|
userService := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
|
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil)))
|
||||||
|
|||||||
@ -41,7 +41,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
|
|||||||
|
|
||||||
userRepo := &stubJWTUserRepo{users: users}
|
userRepo := &stubJWTUserRepo{users: users}
|
||||||
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
authSvc := service.NewAuthService(nil, userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||||
userSvc := service.NewUserService(userRepo, nil, nil)
|
userSvc := service.NewUserService(userRepo, nil, nil, nil)
|
||||||
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
mw := NewJWTAuthMiddleware(authSvc, userSvc)
|
||||||
|
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
|
|||||||
@ -18,6 +18,8 @@ const (
|
|||||||
NonceTemplate = "__CSP_NONCE__"
|
NonceTemplate = "__CSP_NONCE__"
|
||||||
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
|
// CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
|
||||||
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
|
||||||
|
// StripeDomain is the domain for Stripe.js SDK
|
||||||
|
StripeDomain = "https://*.stripe.com"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GenerateNonce generates a cryptographically secure random nonce.
|
// GenerateNonce generates a cryptographically secure random nonce.
|
||||||
@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool {
|
|||||||
strings.HasPrefix(path, "/responses")
|
strings.HasPrefix(path, "/responses")
|
||||||
}
|
}
|
||||||
|
|
||||||
// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
|
// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights,
|
||||||
// This allows the application to work correctly even if the config file has an older CSP policy.
|
// and Stripe.js domains. This allows the application to work correctly even if the
|
||||||
|
// config file has an older CSP policy.
|
||||||
func enhanceCSPPolicy(policy string) string {
|
func enhanceCSPPolicy(policy string) string {
|
||||||
// Add nonce placeholder to script-src if not present
|
// Add nonce placeholder to script-src if not present
|
||||||
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
|
||||||
@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string {
|
|||||||
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
|
policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add Stripe.js domain to script-src and frame-src if not present
|
||||||
|
if !strings.Contains(policy, "stripe.com") {
|
||||||
|
policy = addToDirective(policy, "script-src", StripeDomain)
|
||||||
|
policy = addToDirective(policy, "frame-src", StripeDomain)
|
||||||
|
}
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -407,6 +407,11 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
|||||||
// Beta 策略配置
|
// Beta 策略配置
|
||||||
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
adminSettings.GET("/beta-policy", h.Admin.Setting.GetBetaPolicySettings)
|
||||||
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
adminSettings.PUT("/beta-policy", h.Admin.Setting.UpdateBetaPolicySettings)
|
||||||
|
// Web Search 模拟配置
|
||||||
|
adminSettings.GET("/web-search-emulation", h.Admin.Setting.GetWebSearchEmulationConfig)
|
||||||
|
adminSettings.PUT("/web-search-emulation", h.Admin.Setting.UpdateWebSearchEmulationConfig)
|
||||||
|
adminSettings.POST("/web-search-emulation/test", h.Admin.Setting.TestWebSearchEmulation)
|
||||||
|
adminSettings.POST("/web-search-emulation/reset-usage", h.Admin.Setting.ResetWebSearchUsage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -39,6 +39,7 @@ func RegisterPaymentRoutes(
|
|||||||
orders.GET("/:id", paymentHandler.GetOrder)
|
orders.GET("/:id", paymentHandler.GetOrder)
|
||||||
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
|
orders.POST("/:id/cancel", paymentHandler.CancelOrder)
|
||||||
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
|
orders.POST("/:id/refund-request", paymentHandler.RequestRefund)
|
||||||
|
orders.GET("/refund-eligible-providers", paymentHandler.GetRefundEligibleProviders)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -26,6 +26,15 @@ func RegisterUserRoutes(
|
|||||||
user.PUT("/password", h.User.ChangePassword)
|
user.PUT("/password", h.User.ChangePassword)
|
||||||
user.PUT("", h.User.UpdateProfile)
|
user.PUT("", h.User.UpdateProfile)
|
||||||
|
|
||||||
|
// 通知邮箱管理
|
||||||
|
notifyEmail := user.Group("/notify-email")
|
||||||
|
{
|
||||||
|
notifyEmail.POST("/send-code", h.User.SendNotifyEmailCode)
|
||||||
|
notifyEmail.POST("/verify", h.User.VerifyNotifyEmail)
|
||||||
|
notifyEmail.PUT("/toggle", h.User.ToggleNotifyEmail)
|
||||||
|
notifyEmail.DELETE("", h.User.RemoveNotifyEmail)
|
||||||
|
}
|
||||||
|
|
||||||
// TOTP 双因素认证
|
// TOTP 双因素认证
|
||||||
totp := user.Group("/totp")
|
totp := user.Group("/totp")
|
||||||
{
|
{
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
|
"log/slog"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -969,7 +970,7 @@ func (a *Account) IsOveragesEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。
|
// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用"自动透传(仅替换认证)"。
|
||||||
//
|
//
|
||||||
// 新字段:accounts.extra.openai_passthrough。
|
// 新字段:accounts.extra.openai_passthrough。
|
||||||
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。
|
||||||
@ -1133,7 +1134,7 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
|
|||||||
return resolvedDefault
|
return resolvedDefault
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。
|
// IsOpenAIWSForceHTTPEnabled 返回账号级"强制 HTTP"开关。
|
||||||
// 字段:accounts.extra.openai_ws_force_http。
|
// 字段:accounts.extra.openai_ws_force_http。
|
||||||
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
func (a *Account) IsOpenAIWSForceHTTPEnabled() bool {
|
||||||
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
if a == nil || !a.IsOpenAI() || a.Extra == nil {
|
||||||
@ -1158,7 +1159,7 @@ func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool {
|
|||||||
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。
|
// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用"自动透传(仅替换认证)"。
|
||||||
// 字段:accounts.extra.anthropic_passthrough。
|
// 字段:accounts.extra.anthropic_passthrough。
|
||||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||||
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
||||||
@ -1169,7 +1170,42 @@ func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool {
|
|||||||
return ok && enabled
|
return ok && enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。
|
// WebSearch 模拟三态常量
|
||||||
|
const (
|
||||||
|
WebSearchModeDefault = "default" // 跟随渠道配置
|
||||||
|
WebSearchModeEnabled = "enabled" // 强制开启
|
||||||
|
WebSearchModeDisabled = "disabled" // 强制关闭
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetWebSearchEmulationMode 返回账号的 WebSearch 模拟模式。
|
||||||
|
// 三态:default(跟随渠道)/ enabled(强制开启)/ disabled(强制关闭)。
|
||||||
|
// 兼容旧 bool 值:true→enabled, false→default(并记录 debug 日志)。
|
||||||
|
func (a *Account) GetWebSearchEmulationMode() string {
|
||||||
|
if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil {
|
||||||
|
return WebSearchModeDefault
|
||||||
|
}
|
||||||
|
raw := a.Extra[featureKeyWebSearchEmulation]
|
||||||
|
// Tolerant: legacy bool values (pre-migration or stale writes)
|
||||||
|
if b, ok := raw.(bool); ok {
|
||||||
|
slog.Debug("legacy bool web_search_emulation value", "account_id", a.ID, "value", b)
|
||||||
|
if b {
|
||||||
|
return WebSearchModeEnabled
|
||||||
|
}
|
||||||
|
return WebSearchModeDefault
|
||||||
|
}
|
||||||
|
mode, ok := raw.(string)
|
||||||
|
if !ok {
|
||||||
|
return WebSearchModeDefault
|
||||||
|
}
|
||||||
|
switch mode {
|
||||||
|
case WebSearchModeEnabled, WebSearchModeDisabled:
|
||||||
|
return mode
|
||||||
|
default:
|
||||||
|
return WebSearchModeDefault
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用"仅允许 Codex 官方客户端"。
|
||||||
// 字段:accounts.extra.codex_cli_only。
|
// 字段:accounts.extra.codex_cli_only。
|
||||||
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
// 字段缺失或类型不正确时,按 false(关闭)处理。
|
||||||
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
func (a *Account) IsCodexCLIOnlyEnabled() bool {
|
||||||
@ -1395,6 +1431,19 @@ func (a *Account) getExtraTime(key string) time.Time {
|
|||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getExtraBool 从 Extra 中读取指定 key 的 bool 值
|
||||||
|
func (a *Account) getExtraBool(key string) bool {
|
||||||
|
if a.Extra == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if v, ok := a.Extra[key]; ok {
|
||||||
|
if b, ok := v.(bool); ok {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
// getExtraString 从 Extra 中读取指定 key 的字符串值
|
||||||
func (a *Account) getExtraString(key string) string {
|
func (a *Account) getExtraString(key string) string {
|
||||||
if a.Extra == nil {
|
if a.Extra == nil {
|
||||||
@ -1408,6 +1457,14 @@ func (a *Account) getExtraString(key string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getExtraStringDefault 从 Extra 中读取指定 key 的字符串值,不存在时返回 defaultVal
|
||||||
|
func (a *Account) getExtraStringDefault(key, defaultVal string) string {
|
||||||
|
if v := a.getExtraString(key); v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
|
||||||
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
// getExtraInt 从 Extra 中读取指定 key 的 int 值
|
||||||
func (a *Account) getExtraInt(key string) int {
|
func (a *Account) getExtraInt(key string) int {
|
||||||
if a.Extra == nil {
|
if a.Extra == nil {
|
||||||
@ -1464,6 +1521,62 @@ func (a *Account) GetQuotaResetTimezone() string {
|
|||||||
return "UTC"
|
return "UTC"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Quota Notification Getters ---
|
||||||
|
|
||||||
|
// QuotaNotifyConfig returns the notify configuration for a given quota dimension.
|
||||||
|
// dim must be one of quotaDimDaily, quotaDimWeekly, quotaDimTotal.
|
||||||
|
func (a *Account) QuotaNotifyConfig(dim string) (enabled bool, threshold float64, thresholdType string) {
|
||||||
|
enabled = a.getExtraBool("quota_notify_" + dim + "_enabled")
|
||||||
|
threshold = a.getExtraFloat64("quota_notify_" + dim + "_threshold")
|
||||||
|
thresholdType = a.getExtraStringDefault("quota_notify_"+dim+"_threshold_type", thresholdTypeFixed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyDailyEnabled() bool {
|
||||||
|
e, _, _ := a.QuotaNotifyConfig(quotaDimDaily)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyDailyThreshold() float64 {
|
||||||
|
_, t, _ := a.QuotaNotifyConfig(quotaDimDaily)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyDailyThresholdType() string {
|
||||||
|
_, _, tt := a.QuotaNotifyConfig(quotaDimDaily)
|
||||||
|
return tt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyWeeklyEnabled() bool {
|
||||||
|
e, _, _ := a.QuotaNotifyConfig(quotaDimWeekly)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyWeeklyThreshold() float64 {
|
||||||
|
_, t, _ := a.QuotaNotifyConfig(quotaDimWeekly)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyWeeklyThresholdType() string {
|
||||||
|
_, _, tt := a.QuotaNotifyConfig(quotaDimWeekly)
|
||||||
|
return tt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyTotalEnabled() bool {
|
||||||
|
e, _, _ := a.QuotaNotifyConfig(quotaDimTotal)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyTotalThreshold() float64 {
|
||||||
|
_, t, _ := a.QuotaNotifyConfig(quotaDimTotal)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetQuotaNotifyTotalThresholdType() string {
|
||||||
|
_, _, tt := a.QuotaNotifyConfig(quotaDimTotal)
|
||||||
|
return tt
|
||||||
|
}
|
||||||
|
|
||||||
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
// nextFixedDailyReset 计算在 after 之后的下一个每日固定重置时间点
|
||||||
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
func nextFixedDailyReset(hour int, tz *time.Location, after time.Time) time.Time {
|
||||||
t := after.In(tz)
|
t := after.In(tz)
|
||||||
|
|||||||
236
backend/internal/service/account_stats_pricing.go
Normal file
236
backend/internal/service/account_stats_pricing.go
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// resolveAccountStatsCost 计算账号统计定价费用。
|
||||||
|
// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。
|
||||||
|
//
|
||||||
|
// 优先级(先命中为准):
|
||||||
|
// 1. 自定义规则(始终尝试,不依赖 ApplyPricingToAccountStats 开关)
|
||||||
|
// 2. ApplyPricingToAccountStats 启用时,直接使用本次请求的客户计费(倍率前的 totalCost)
|
||||||
|
// 3. 模型定价文件(LiteLLM)中上游模型的默认价格
|
||||||
|
// 4. nil → 走默认公式(total_cost × account_rate_multiplier)
|
||||||
|
//
|
||||||
|
// upstreamModel 是最终发往上游的模型 ID。
|
||||||
|
// totalCost 是本次请求的客户计费(倍率前),用于优先级 2。
|
||||||
|
func resolveAccountStatsCost(
|
||||||
|
ctx context.Context,
|
||||||
|
channelService *ChannelService,
|
||||||
|
billingService *BillingService,
|
||||||
|
accountID int64,
|
||||||
|
groupID int64,
|
||||||
|
upstreamModel string,
|
||||||
|
tokens UsageTokens,
|
||||||
|
requestCount int,
|
||||||
|
totalCost float64,
|
||||||
|
) *float64 {
|
||||||
|
if channelService == nil || upstreamModel == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
channel, err := channelService.GetChannelForGroup(ctx, groupID)
|
||||||
|
if err != nil || channel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
platform := channelService.GetGroupPlatform(ctx, groupID)
|
||||||
|
|
||||||
|
// 优先级 1:自定义规则(始终尝试)
|
||||||
|
if cost := tryCustomRules(channel, accountID, groupID, platform, upstreamModel, tokens, requestCount); cost != nil {
|
||||||
|
return cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先级 2:渠道开启"应用模型定价到账号统计"时,直接使用客户计费(倍率前)
|
||||||
|
if channel.ApplyPricingToAccountStats {
|
||||||
|
cost := totalCost
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先级 3:模型定价文件(LiteLLM)默认价格
|
||||||
|
if billingService != nil {
|
||||||
|
return tryModelFilePricing(billingService, upstreamModel, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryModelFilePricing 使用模型定价文件(LiteLLM/fallback)中的标准价格计算费用。
|
||||||
|
func tryModelFilePricing(billingService *BillingService, model string, tokens UsageTokens) *float64 {
|
||||||
|
pricing, err := billingService.GetModelPricing(model)
|
||||||
|
if err != nil || pricing == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cost := float64(tokens.InputTokens)*pricing.InputPricePerToken +
|
||||||
|
float64(tokens.OutputTokens)*pricing.OutputPricePerToken +
|
||||||
|
float64(tokens.CacheCreationTokens)*pricing.CacheCreationPricePerToken +
|
||||||
|
float64(tokens.CacheReadTokens)*pricing.CacheReadPricePerToken +
|
||||||
|
float64(tokens.ImageOutputTokens)*pricing.ImageOutputPricePerToken
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。
|
||||||
|
func tryCustomRules(
|
||||||
|
channel *Channel, accountID, groupID int64,
|
||||||
|
platform, model string, tokens UsageTokens, requestCount int,
|
||||||
|
) *float64 {
|
||||||
|
modelLower := strings.ToLower(model)
|
||||||
|
for _, rule := range channel.AccountStatsPricingRules {
|
||||||
|
if !matchAccountStatsRule(&rule, accountID, groupID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pricing := findPricingForModel(rule.Pricing, platform, modelLower)
|
||||||
|
if pricing == nil {
|
||||||
|
continue // 规则匹配但模型不在规则定价中,继续下一条
|
||||||
|
}
|
||||||
|
return calculateStatsCost(pricing, tokens, requestCount)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。
|
||||||
|
// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。
|
||||||
|
// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。
|
||||||
|
func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool {
|
||||||
|
if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, id := range rule.AccountIDs {
|
||||||
|
if id == accountID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, id := range rule.GroupIDs {
|
||||||
|
if id == groupID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPricingForModel 在定价列表中查找匹配的模型定价。
|
||||||
|
// 先精确匹配,再通配符匹配(按配置顺序,先匹配先使用)。
|
||||||
|
func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing {
|
||||||
|
// 精确匹配优先
|
||||||
|
for i := range pricingList {
|
||||||
|
p := &pricingList[i]
|
||||||
|
if !isPlatformMatch(platform, p.Platform) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, m := range p.Models {
|
||||||
|
if strings.ToLower(m) == modelLower {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 通配符匹配:按配置顺序,先匹配先使用
|
||||||
|
for i := range pricingList {
|
||||||
|
p := &pricingList[i]
|
||||||
|
if !isPlatformMatch(platform, p.Platform) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, m := range p.Models {
|
||||||
|
ml := strings.ToLower(m)
|
||||||
|
if !strings.HasSuffix(ml, "*") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prefix := strings.TrimSuffix(ml, "*")
|
||||||
|
if strings.HasPrefix(modelLower, prefix) {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。
|
||||||
|
func isPlatformMatch(queryPlatform, pricingPlatform string) bool {
|
||||||
|
if queryPlatform == "" || pricingPlatform == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return queryPlatform == pricingPlatform
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。
|
||||||
|
func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 {
|
||||||
|
if pricing == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
switch pricing.BillingMode {
|
||||||
|
case BillingModePerRequest, BillingModeImage:
|
||||||
|
return calculatePerRequestStatsCost(pricing, requestCount)
|
||||||
|
default:
|
||||||
|
return calculateTokenStatsCost(pricing, tokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculatePerRequestStatsCost 按次/图片计费。
|
||||||
|
func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 {
|
||||||
|
if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cost := *pricing.PerRequestPrice * float64(requestCount)
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateTokenStatsCost Token 计费。
|
||||||
|
// If the pricing has intervals, find the matching interval by total token count
|
||||||
|
// and use its prices instead of the flat pricing fields.
|
||||||
|
func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 {
|
||||||
|
p := pricing
|
||||||
|
if len(pricing.Intervals) > 0 {
|
||||||
|
totalTokens := tokens.InputTokens + tokens.OutputTokens + tokens.CacheCreationTokens + tokens.CacheReadTokens
|
||||||
|
if iv := FindMatchingInterval(pricing.Intervals, totalTokens); iv != nil {
|
||||||
|
p = &ChannelModelPricing{
|
||||||
|
InputPrice: iv.InputPrice,
|
||||||
|
OutputPrice: iv.OutputPrice,
|
||||||
|
CacheWritePrice: iv.CacheWritePrice,
|
||||||
|
CacheReadPrice: iv.CacheReadPrice,
|
||||||
|
PerRequestPrice: iv.PerRequestPrice,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deref := func(ptr *float64) float64 {
|
||||||
|
if ptr == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *ptr
|
||||||
|
}
|
||||||
|
cost := float64(tokens.InputTokens)*deref(p.InputPrice) +
|
||||||
|
float64(tokens.OutputTokens)*deref(p.OutputPrice) +
|
||||||
|
float64(tokens.CacheCreationTokens)*deref(p.CacheWritePrice) +
|
||||||
|
float64(tokens.CacheReadTokens)*deref(p.CacheReadPrice) +
|
||||||
|
float64(tokens.ImageOutputTokens)*deref(p.ImageOutputPrice)
|
||||||
|
if cost <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
|
||||||
|
// It resolves the upstream model (falling back to the requested model) and calls
|
||||||
|
// the 4-level priority chain via resolveAccountStatsCost.
|
||||||
|
func applyAccountStatsCost(
|
||||||
|
ctx context.Context,
|
||||||
|
usageLog *UsageLog,
|
||||||
|
cs *ChannelService, bs *BillingService,
|
||||||
|
accountID int64, groupID int64,
|
||||||
|
upstreamModel, requestedModel string,
|
||||||
|
tokens UsageTokens,
|
||||||
|
totalCost float64,
|
||||||
|
) {
|
||||||
|
model := upstreamModel
|
||||||
|
if model == "" {
|
||||||
|
model = requestedModel
|
||||||
|
}
|
||||||
|
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
||||||
|
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
|
||||||
|
)
|
||||||
|
}
|
||||||
771
backend/internal/service/account_stats_pricing_test.go
Normal file
771
backend/internal/service/account_stats_pricing_test.go
Normal file
@ -0,0 +1,771 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// matchAccountStatsRule
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{}
|
||||||
|
require.False(t, matchAccountStatsRule(rule, 1, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 999, 20))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 2, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.True(t, matchAccountStatsRule(rule, 999, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) {
|
||||||
|
rule := &AccountStatsPricingRule{
|
||||||
|
AccountIDs: []int64{1, 2},
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
}
|
||||||
|
require.False(t, matchAccountStatsRule(rule, 999, 999))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// findPricingForModel
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestFindPricingForModel(t *testing.T) {
|
||||||
|
exactPricing := ChannelModelPricing{
|
||||||
|
ID: 1,
|
||||||
|
Models: []string{"claude-opus-4"},
|
||||||
|
}
|
||||||
|
wildcardPricing := ChannelModelPricing{
|
||||||
|
ID: 2,
|
||||||
|
Models: []string{"claude-*"},
|
||||||
|
}
|
||||||
|
platformPricing := ChannelModelPricing{
|
||||||
|
ID: 3,
|
||||||
|
Platform: "openai",
|
||||||
|
Models: []string{"gpt-4o"},
|
||||||
|
}
|
||||||
|
emptyPlatformPricing := ChannelModelPricing{
|
||||||
|
ID: 4,
|
||||||
|
Models: []string{"gemini-2.5-pro"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
list []ChannelModelPricing
|
||||||
|
platform string
|
||||||
|
model string
|
||||||
|
wantID int64
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact match",
|
||||||
|
list: []ChannelModelPricing{exactPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact match case insensitive",
|
||||||
|
list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard match",
|
||||||
|
list: []ChannelModelPricing{wildcardPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact match takes priority over wildcard",
|
||||||
|
list: []ChannelModelPricing{wildcardPricing, exactPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "platform mismatch skipped",
|
||||||
|
list: []ChannelModelPricing{platformPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty platform in pricing matches any",
|
||||||
|
list: []ChannelModelPricing{emptyPlatformPricing},
|
||||||
|
platform: "gemini",
|
||||||
|
model: "gemini-2.5-pro",
|
||||||
|
wantID: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty platform in query matches any pricing platform",
|
||||||
|
list: []ChannelModelPricing{platformPricing},
|
||||||
|
platform: "",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantID: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match at all",
|
||||||
|
list: []ChannelModelPricing{exactPricing, wildcardPricing},
|
||||||
|
platform: "anthropic",
|
||||||
|
model: "gpt-4o",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list returns nil",
|
||||||
|
list: nil,
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wildcard matches by config order (first match wins)",
|
||||||
|
list: []ChannelModelPricing{
|
||||||
|
{ID: 10, Models: []string{"claude-*"}},
|
||||||
|
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||||
|
},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-opus-4",
|
||||||
|
wantID: 10, // config order: "claude-*" is first and matches, so it wins
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "shorter wildcard used when longer does not match",
|
||||||
|
list: []ChannelModelPricing{
|
||||||
|
{ID: 10, Models: []string{"claude-*"}},
|
||||||
|
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||||
|
},
|
||||||
|
platform: "",
|
||||||
|
model: "claude-sonnet-4",
|
||||||
|
wantID: 10, // only "claude-*" matches
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := findPricingForModel(tt.list, tt.platform, tt.model)
|
||||||
|
if tt.wantNil {
|
||||||
|
require.Nil(t, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.Equal(t, tt.wantID, result.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// calculateStatsCost
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_NilPricing(t *testing.T) {
|
||||||
|
result := calculateStatsCost(nil, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
CacheWritePrice: testPtrFloat64(0.003),
|
||||||
|
CacheReadPrice: testPtrFloat64(0.0005),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheCreationTokens: 200,
|
||||||
|
CacheReadTokens: 300,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||||
|
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||||
|
require.InDelta(t, 0.95, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
ImageOutputPrice: testPtrFloat64(0.01),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
ImageOutputTokens: 10,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||||
|
require.InDelta(t, 0.3, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
// OutputPrice, CacheWritePrice, etc. are all nil → treated as 0
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheCreationTokens: 200,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// Only input contributes: 100*0.001 = 0.1
|
||||||
|
require.InDelta(t, 0.1, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{} // all zeros
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
// totalCost == 0 → returns nil (does not override, falls back to default formula)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.05),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 999, OutputTokens: 999}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 3)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 0.05 * 3 = 0.15
|
||||||
|
require.InDelta(t, 0.15, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
// PerRequestPrice is nil
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0),
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
// price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_ImageBilling(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeImage,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.10),
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 2)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 0.10 * 2 = 0.20
|
||||||
|
require.InDelta(t, 0.20, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) {
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
BillingMode: BillingModeImage,
|
||||||
|
// PerRequestPrice is nil
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, UsageTokens{}, 1)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) {
|
||||||
|
// BillingMode is empty string (default) → falls into token billing
|
||||||
|
pricing := &ChannelModelPricing{
|
||||||
|
InputPrice: testPtrFloat64(0.001),
|
||||||
|
OutputPrice: testPtrFloat64(0.002),
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
result := calculateStatsCost(pricing, tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// tryCustomRules — 多规则顺序测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestTryCustomRules_FirstMatchWins(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0
|
||||||
|
require.InDelta(t, 2.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
AccountIDs: []int64{888}, // 不匹配
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1}, // 匹配
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0
|
||||||
|
require.InDelta(t, 5.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
AccountIDs: []int64{888},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.Nil(t, result) // 账号和分组都不匹配
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{1},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// tryModelFilePricing
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// newTestBillingServiceWithPrices creates a BillingService with pre-populated
|
||||||
|
// fallback prices for testing. No config or pricing service is needed.
|
||||||
|
// The key must match what getFallbackPricing resolves to for a given model name.
|
||||||
|
// E.g., model "claude-sonnet-4" resolves to key "claude-sonnet-4".
|
||||||
|
func newTestBillingServiceWithPrices(prices map[string]*ModelPricing) *BillingService {
|
||||||
|
return &BillingService{
|
||||||
|
fallbackPrices: prices,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_Success(t *testing.T) {
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
InputPricePerToken: 0.001,
|
||||||
|
OutputPricePerToken: 0.002,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_PricingNotFound(t *testing.T) {
|
||||||
|
// "nonexistent-model" does not match any fallback pattern
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
result := tryModelFilePricing(bs, "nonexistent-model", tokens)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_NilFallback(t *testing.T) {
|
||||||
|
// getFallbackPricing returns nil when key maps to nil
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": nil,
|
||||||
|
})
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_ZeroCost(t *testing.T) {
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
InputPricePerToken: 0.001,
|
||||||
|
OutputPricePerToken: 0.002,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tokens := UsageTokens{} // all zero tokens → cost = 0 → nil
|
||||||
|
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_WithImageOutput(t *testing.T) {
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
InputPricePerToken: 0.001,
|
||||||
|
OutputPricePerToken: 0.002,
|
||||||
|
ImageOutputPricePerToken: 0.01,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
ImageOutputTokens: 10,
|
||||||
|
}
|
||||||
|
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3
|
||||||
|
require.InDelta(t, 0.3, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryModelFilePricing_WithCacheTokens(t *testing.T) {
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
InputPricePerToken: 0.001,
|
||||||
|
OutputPricePerToken: 0.002,
|
||||||
|
CacheCreationPricePerToken: 0.003,
|
||||||
|
CacheReadPricePerToken: 0.0005,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
tokens := UsageTokens{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheCreationTokens: 200,
|
||||||
|
CacheReadTokens: 300,
|
||||||
|
}
|
||||||
|
result := tryModelFilePricing(bs, "claude-sonnet-4", tokens)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005
|
||||||
|
// = 0.1 + 0.1 + 0.6 + 0.15 = 0.95
|
||||||
|
require.InDelta(t, 0.95, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// resolveAccountStatsCost — integration tests covering the 4-level priority chain
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_NilChannelService(t *testing.T) {
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
nil, // channelService is nil
|
||||||
|
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||||
|
1, 1, "claude-sonnet-4",
|
||||||
|
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_EmptyUpstreamModel(t *testing.T) {
|
||||||
|
cs := newTestChannelServiceForStats(t, &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
}, 1, "")
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs,
|
||||||
|
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||||
|
1, 1, "", // empty upstream model
|
||||||
|
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_GetChannelForGroupReturnsNil(t *testing.T) {
|
||||||
|
// Group 99 is NOT in the cache, so GetChannelForGroup returns nil
|
||||||
|
cs := newTestChannelServiceForStats(t, &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
}, 1, "")
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs,
|
||||||
|
newTestBillingServiceWithPrices(map[string]*ModelPricing{}),
|
||||||
|
1, 99, "claude-sonnet-4", // groupID 99 has no channel
|
||||||
|
UsageTokens{InputTokens: 100}, 1, 0.5,
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_HitsCustomRule(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{
|
||||||
|
ID: 100,
|
||||||
|
Models: []string{"claude-sonnet-4"},
|
||||||
|
InputPrice: testPtrFloat64(0.01),
|
||||||
|
OutputPrice: testPtrFloat64(0.02),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, nil, // billingService not needed when custom rule hits
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
tokens, 1, 999.0, // totalCost ignored because custom rule hits
|
||||||
|
)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.01 + 50*0.02 = 1.0 + 1.0 = 2.0
|
||||||
|
require.InDelta(t, 2.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_UsesTotalCost(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: true,
|
||||||
|
// No custom rules
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, nil,
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
tokens, 1, 0.75, // totalCost = 0.75
|
||||||
|
)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
require.InDelta(t, 0.75, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_ApplyPricingToAccountStats_ZeroTotalCost_ReturnsNil(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: true,
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, nil,
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
UsageTokens{}, 1, 0.0, // totalCost = 0
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_FallsBackToLiteLLM(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: false, // not enabled
|
||||||
|
// No custom rules
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{
|
||||||
|
"claude-sonnet-4": {
|
||||||
|
InputPricePerToken: 0.001,
|
||||||
|
OutputPricePerToken: 0.002,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, bs,
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
tokens, 1, 999.0, // totalCost ignored
|
||||||
|
)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2
|
||||||
|
require.InDelta(t, 0.2, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_AllMiss_ReturnsNil(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: false,
|
||||||
|
// No custom rules
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
// BillingService with no pricing for the model
|
||||||
|
bs := newTestBillingServiceWithPrices(map[string]*ModelPricing{})
|
||||||
|
|
||||||
|
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50}
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, bs,
|
||||||
|
1, 10, "totally-unknown-model",
|
||||||
|
tokens, 1, 0.0,
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_NilBillingService_SkipsLiteLLM(t *testing.T) {
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: false,
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, nil, // billingService is nil
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
UsageTokens{InputTokens: 100}, 1, 0.0,
|
||||||
|
)
|
||||||
|
require.Nil(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAccountStatsCost_CustomRulePriorityOverApplyPricing(t *testing.T) {
|
||||||
|
// Both custom rule and ApplyPricingToAccountStats are configured;
|
||||||
|
// custom rule should take precedence.
|
||||||
|
channel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
ApplyPricingToAccountStats: true,
|
||||||
|
AccountStatsPricingRules: []AccountStatsPricingRule{
|
||||||
|
{
|
||||||
|
GroupIDs: []int64{10},
|
||||||
|
Pricing: []ChannelModelPricing{
|
||||||
|
{
|
||||||
|
ID: 100,
|
||||||
|
Models: []string{"claude-sonnet-4"},
|
||||||
|
InputPrice: testPtrFloat64(0.05),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cs := newTestChannelServiceForStats(t, channel, 10, "anthropic")
|
||||||
|
|
||||||
|
tokens := UsageTokens{InputTokens: 100}
|
||||||
|
|
||||||
|
result := resolveAccountStatsCost(
|
||||||
|
context.Background(),
|
||||||
|
cs, nil,
|
||||||
|
1, 10, "claude-sonnet-4",
|
||||||
|
tokens, 1, 99.0, // totalCost = 99.0 (would be used if ApplyPricing wins)
|
||||||
|
)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
// Custom rule: 100*0.05 = 5.0 (NOT 99.0 from totalCost)
|
||||||
|
require.InDelta(t, 5.0, *result, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// helpers for resolveAccountStatsCost tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// newTestChannelServiceForStats creates a ChannelService with a single channel
|
||||||
|
// mapped to the given groupID, suitable for resolveAccountStatsCost tests.
|
||||||
|
func newTestChannelServiceForStats(t *testing.T, channel *Channel, groupID int64, platform string) *ChannelService {
|
||||||
|
t.Helper()
|
||||||
|
cache := newEmptyChannelCache()
|
||||||
|
cache.channelByGroupID[groupID] = channel
|
||||||
|
cache.groupPlatform[groupID] = platform
|
||||||
|
cs := &ChannelService{}
|
||||||
|
cache.loadedAt = time.Now()
|
||||||
|
cs.cache.Store(cache)
|
||||||
|
return cs
|
||||||
|
}
|
||||||
@ -515,22 +515,10 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
|||||||
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
_ = s.accountRepo.UpdateExtra(ctx, account.ID, updates)
|
||||||
mergeAccountExtra(account, updates)
|
mergeAccountExtra(account, updates)
|
||||||
}
|
}
|
||||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
|
||||||
if resetAt := codexRateLimitResetAtFromSnapshot(snapshot, time.Now()); resetAt != nil {
|
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
|
||||||
account.RateLimitResetAt = resetAt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if isOAuth && s.accountRepo != nil {
|
|
||||||
if resetAt := (&RateLimitService{}).calculateOpenAI429ResetTime(resp.Header); resetAt != nil {
|
|
||||||
_ = s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt)
|
|
||||||
account.RateLimitResetAt = resetAt
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 401 Unauthorized: 标记账号为永久错误
|
// 401 Unauthorized: 标记账号为永久错误
|
||||||
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
|
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
|
||||||
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
|
errMsg := fmt.Sprintf("Authentication failed (401): %s", string(body))
|
||||||
|
|||||||
@ -111,7 +111,7 @@ func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.
|
|||||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
func TestAccountTestService_OpenAI429PersistsSnapshotWithoutRateLimit(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
ctx, _ := newTestContext()
|
ctx, _ := newTestContext()
|
||||||
|
|
||||||
@ -138,10 +138,7 @@ func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T)
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.NotEmpty(t, repo.updatedExtra)
|
require.NotEmpty(t, repo.updatedExtra)
|
||||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
require.Zero(t, repo.rateLimitedID)
|
||||||
require.NotNil(t, repo.rateLimitedAt)
|
require.Nil(t, repo.rateLimitedAt)
|
||||||
require.NotNil(t, account.RateLimitResetAt)
|
require.Nil(t, account.RateLimitResetAt)
|
||||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
|
||||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -499,7 +499,6 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
if account == nil {
|
if account == nil {
|
||||||
return usage, nil
|
return usage, nil
|
||||||
}
|
}
|
||||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, account, now)
|
|
||||||
|
|
||||||
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
if progress := buildCodexUsageProgressFromExtra(account.Extra, "5h", now); progress != nil {
|
||||||
usage.FiveHour = progress
|
usage.FiveHour = progress
|
||||||
@ -509,11 +508,8 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
|
||||||
if updates, resetAt, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && (len(updates) > 0 || resetAt != nil) {
|
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
|
||||||
mergeAccountExtra(account, updates)
|
mergeAccountExtra(account, updates)
|
||||||
if resetAt != nil {
|
|
||||||
account.RateLimitResetAt = resetAt
|
|
||||||
}
|
|
||||||
if usage.UpdatedAt == nil {
|
if usage.UpdatedAt == nil {
|
||||||
usage.UpdatedAt = &now
|
usage.UpdatedAt = &now
|
||||||
}
|
}
|
||||||
@ -594,26 +590,26 @@ func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, no
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, *time.Time, error) {
|
func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, account *Account) (map[string]any, error) {
|
||||||
if account == nil || !account.IsOAuth() {
|
if account == nil || !account.IsOAuth() {
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
accessToken := account.GetOpenAIAccessToken()
|
accessToken := account.GetOpenAIAccessToken()
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, nil, fmt.Errorf("no access token available")
|
return nil, fmt.Errorf("no access token available")
|
||||||
}
|
}
|
||||||
modelID := openaipkg.DefaultTestModel
|
modelID := openaipkg.DefaultTestModel
|
||||||
payload := createOpenAITestPayload(modelID, true)
|
payload := createOpenAITestPayload(modelID, true)
|
||||||
payloadBytes, err := json.Marshal(payload)
|
payloadBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
return nil, fmt.Errorf("marshal openai probe payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
reqCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, chatgptCodexURL, bytes.NewReader(payloadBytes))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create openai probe request: %w", err)
|
return nil, fmt.Errorf("create openai probe request: %w", err)
|
||||||
}
|
}
|
||||||
req.Host = "chatgpt.com"
|
req.Host = "chatgpt.com"
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@ -642,67 +638,51 @@ func (s *AccountUsageService) probeOpenAICodexSnapshot(ctx context.Context, acco
|
|||||||
ResponseHeaderTimeout: 10 * time.Second,
|
ResponseHeaderTimeout: 10 * time.Second,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("build openai probe client: %w", err)
|
return nil, fmt.Errorf("build openai probe client: %w", err)
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
return nil, fmt.Errorf("openai codex probe request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(resp)
|
updates, err := extractOpenAICodexProbeUpdates(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(updates) > 0 || resetAt != nil {
|
if len(updates) > 0 {
|
||||||
s.persistOpenAICodexProbeSnapshot(account.ID, updates, resetAt)
|
s.persistOpenAICodexProbeSnapshot(account.ID, updates)
|
||||||
return updates, resetAt, nil
|
return updates, nil
|
||||||
}
|
}
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any, resetAt *time.Time) {
|
func (s *AccountUsageService) persistOpenAICodexProbeSnapshot(accountID int64, updates map[string]any) {
|
||||||
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
if s == nil || s.accountRepo == nil || accountID <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(updates) == 0 && resetAt == nil {
|
if len(updates) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
updateCtx, updateCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer updateCancel()
|
defer updateCancel()
|
||||||
if len(updates) > 0 {
|
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
||||||
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
|
|
||||||
}
|
|
||||||
if resetAt != nil {
|
|
||||||
_ = s.accountRepo.SetRateLimited(updateCtx, accountID, *resetAt)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractOpenAICodexProbeSnapshot(resp *http.Response) (map[string]any, *time.Time, error) {
|
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||||
baseTime := time.Now()
|
return buildCodexUsageExtraUpdates(snapshot, time.Now()), nil
|
||||||
updates := buildCodexUsageExtraUpdates(snapshot, baseTime)
|
|
||||||
resetAt := codexRateLimitResetAtFromSnapshot(snapshot, baseTime)
|
|
||||||
if len(updates) > 0 {
|
|
||||||
return updates, resetAt, nil
|
|
||||||
}
|
|
||||||
return nil, resetAt, nil
|
|
||||||
}
|
}
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
return nil, nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
return nil, fmt.Errorf("openai codex probe returned status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
return nil, nil, nil
|
return nil, nil
|
||||||
}
|
|
||||||
|
|
||||||
func extractOpenAICodexProbeUpdates(resp *http.Response) (map[string]any, error) {
|
|
||||||
updates, _, err := extractOpenAICodexProbeSnapshot(resp)
|
|
||||||
return updates, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeAccountExtra(account *Account, updates map[string]any) {
|
func mergeAccountExtra(account *Account, updates map[string]any) {
|
||||||
|
|||||||
@ -92,30 +92,7 @@ func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotOnlyUpdatesExtra(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
headers := make(http.Header)
|
|
||||||
headers.Set("x-codex-primary-used-percent", "100")
|
|
||||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
|
||||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
|
||||||
headers.Set("x-codex-secondary-used-percent", "100")
|
|
||||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
|
||||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
|
||||||
|
|
||||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
|
||||||
}
|
|
||||||
if len(updates) == 0 {
|
|
||||||
t.Fatal("expected codex probe updates from 429 headers")
|
|
||||||
}
|
|
||||||
if resetAt == nil {
|
|
||||||
t.Fatal("expected resetAt from exhausted codex headers")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
repo := &accountUsageCodexProbeRepo{
|
repo := &accountUsageCodexProbeRepo{
|
||||||
@ -123,12 +100,10 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
|
|||||||
rateLimitCh: make(chan time.Time, 1),
|
rateLimitCh: make(chan time.Time, 1),
|
||||||
}
|
}
|
||||||
svc := &AccountUsageService{accountRepo: repo}
|
svc := &AccountUsageService{accountRepo: repo}
|
||||||
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
|
||||||
|
|
||||||
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||||
"codex_7d_used_percent": 100.0,
|
"codex_7d_used_percent": 100.0,
|
||||||
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
"codex_7d_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
|
||||||
}, &resetAt)
|
})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case updates := <-repo.updateExtraCh:
|
case updates := <-repo.updateExtraCh:
|
||||||
@ -136,16 +111,49 @@ func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *tes
|
|||||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||||
}
|
}
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatal("waiting for codex probe extra persistence timed out")
|
t.Fatal("等待 codex 探测快照写入 extra 超时")
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case got := <-repo.rateLimitCh:
|
case got := <-repo.rateLimitCh:
|
||||||
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
t.Fatalf("不应将探测快照写入运行时限流状态: %v", got)
|
||||||
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
case <-time.After(200 * time.Millisecond):
|
||||||
}
|
}
|
||||||
case <-time.After(2 * time.Second):
|
}
|
||||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
|
||||||
|
func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
resetAt := time.Now().Add(6 * 24 * time.Hour).UTC().Truncate(time.Second)
|
||||||
|
repo := &accountUsageCodexProbeRepo{
|
||||||
|
rateLimitCh: make(chan time.Time, 1),
|
||||||
|
}
|
||||||
|
svc := &AccountUsageService{accountRepo: repo}
|
||||||
|
account := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"codex_5h_used_percent": 1.0,
|
||||||
|
"codex_5h_reset_at": time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second).Format(time.RFC3339),
|
||||||
|
"codex_7d_used_percent": 100.0,
|
||||||
|
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, err := svc.getOpenAIUsage(context.Background(), account)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getOpenAIUsage() error = %v", err)
|
||||||
|
}
|
||||||
|
if usage.SevenDay == nil || usage.SevenDay.Utilization != 100.0 {
|
||||||
|
t.Fatalf("预期 7 天用量仍然可见,实际为 %#v", usage.SevenDay)
|
||||||
|
}
|
||||||
|
if account.RateLimitResetAt != nil {
|
||||||
|
t.Fatalf("不应让已耗尽的 codex extra 改写运行时限流状态: %v", account.RateLimitResetAt)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case got := <-repo.rateLimitCh:
|
||||||
|
t.Fatalf("不应将已耗尽的 codex extra 持久化为运行时限流状态: %v", got)
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
105
backend/internal/service/account_websearch_test.go
Normal file
105
backend/internal/service/account_websearch_test.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_Enabled(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_Disabled(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "disabled"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDisabled, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_Default(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "default"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_UnknownString(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "unknown"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_OldBoolTrue(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: true},
|
||||||
|
}
|
||||||
|
// bool true → tolerant fallback → enabled (not default)
|
||||||
|
require.Equal(t, WebSearchModeEnabled, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_OldBoolFalse(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: false},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_NilAccount(t *testing.T) {
|
||||||
|
var a *Account
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_NilExtra(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: nil,
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_MissingField(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_NonAnthropicPlatform(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetWebSearchEmulationMode_NonAPIKeyType(t *testing.T) {
|
||||||
|
a := &Account{
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Extra: map[string]any{featureKeyWebSearchEmulation: "enabled"},
|
||||||
|
}
|
||||||
|
require.Equal(t, WebSearchModeDefault, a.GetWebSearchEmulationMode())
|
||||||
|
}
|
||||||
@ -1470,10 +1470,6 @@ func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
now := time.Now()
|
|
||||||
for i := range accounts {
|
|
||||||
syncOpenAICodexRateLimitFromExtra(ctx, s.accountRepo, &accounts[i], now)
|
|
||||||
}
|
|
||||||
return accounts, result.Total, nil
|
return accounts, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -65,14 +65,14 @@ func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (boo
|
|||||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
|
||||||
panic("unexpected")
|
|
||||||
}
|
|
||||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||||
|
func (s *userRepoStubForGroupUpdate) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
||||||
type apiKeyRepoStubForGroupUpdate struct {
|
type apiKeyRepoStubForGroupUpdate struct {
|
||||||
@ -131,9 +131,6 @@ func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, str
|
|||||||
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
|
|
||||||
panic("unexpected")
|
|
||||||
}
|
|
||||||
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
|
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
@ -158,6 +155,9 @@ func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, in
|
|||||||
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
|
func (s *apiKeyRepoStubForGroupUpdate) UpdateGroupIDByUserAndGroup(context.Context, int64, int64, int64) (int64, error) {
|
||||||
|
panic("unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
|
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
|
||||||
type groupRepoStubForGroupUpdate struct {
|
type groupRepoStubForGroupUpdate struct {
|
||||||
|
|||||||
@ -12,12 +12,12 @@ import (
|
|||||||
|
|
||||||
type accountRepoStubForClearAccountError struct {
|
type accountRepoStubForClearAccountError struct {
|
||||||
mockAccountRepoForGemini
|
mockAccountRepoForGemini
|
||||||
account *Account
|
account *Account
|
||||||
clearErrorCalls int
|
clearErrorCalls int
|
||||||
clearRateLimitCalls int
|
clearRateLimitCalls int
|
||||||
clearAntigravityCalls int
|
clearAntigravityCalls int
|
||||||
clearModelRateLimitCalls int
|
clearModelRateLimitCalls int
|
||||||
clearTempUnschedCalls int
|
clearTempUnschedCalls int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
|
func (r *accountRepoStubForClearAccountError) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||||
@ -60,13 +60,13 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
|||||||
resetAt := time.Now().Add(5 * time.Minute)
|
resetAt := time.Now().Add(5 * time.Minute)
|
||||||
repo := &accountRepoStubForClearAccountError{
|
repo := &accountRepoStubForClearAccountError{
|
||||||
account: &Account{
|
account: &Account{
|
||||||
ID: 31,
|
ID: 31,
|
||||||
Platform: PlatformOpenAI,
|
Platform: PlatformOpenAI,
|
||||||
Type: AccountTypeOAuth,
|
Type: AccountTypeOAuth,
|
||||||
Status: StatusError,
|
Status: StatusError,
|
||||||
ErrorMessage: "refresh failed",
|
ErrorMessage: "refresh failed",
|
||||||
RateLimitResetAt: &resetAt,
|
RateLimitResetAt: &resetAt,
|
||||||
TempUnschedulableUntil: &until,
|
TempUnschedulableUntil: &until,
|
||||||
TempUnschedulableReason: "missing refresh token",
|
TempUnschedulableReason: "missing refresh token",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,6 +34,15 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Balance float64 `json:"balance"`
|
Balance float64 `json:"balance"`
|
||||||
Concurrency int `json:"concurrency"`
|
Concurrency int `json:"concurrency"`
|
||||||
|
|
||||||
|
// Balance notification fields (required for CheckBalanceAfterDeduction)
|
||||||
|
Email string `json:"email"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
BalanceNotifyEnabled bool `json:"balance_notify_enabled"`
|
||||||
|
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
|
||||||
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||||
|
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
||||||
|
TotalRecharged float64 `json:"total_recharged"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -13,7 +14,7 @@ import (
|
|||||||
"github.com/dgraph-io/ristretto"
|
"github.com/dgraph-io/ristretto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiKeyAuthSnapshotVersion = 3
|
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
|
||||||
|
|
||||||
type apiKeyAuthCacheConfig struct {
|
type apiKeyAuthCacheConfig struct {
|
||||||
l1Size int
|
l1Size int
|
||||||
@ -99,7 +100,7 @@ func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context
|
|||||||
s.authCacheL1.Del(cacheKey)
|
s.authCacheL1.Del(cacheKey)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
||||||
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
|
slog.Warn("failed to start auth cache invalidation subscriber", "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,11 +220,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
RateLimit1d: apiKey.RateLimit1d,
|
RateLimit1d: apiKey.RateLimit1d,
|
||||||
RateLimit7d: apiKey.RateLimit7d,
|
RateLimit7d: apiKey.RateLimit7d,
|
||||||
User: APIKeyAuthUserSnapshot{
|
User: APIKeyAuthUserSnapshot{
|
||||||
ID: apiKey.User.ID,
|
ID: apiKey.User.ID,
|
||||||
Status: apiKey.User.Status,
|
Status: apiKey.User.Status,
|
||||||
Role: apiKey.User.Role,
|
Role: apiKey.User.Role,
|
||||||
Balance: apiKey.User.Balance,
|
Balance: apiKey.User.Balance,
|
||||||
Concurrency: apiKey.User.Concurrency,
|
Concurrency: apiKey.User.Concurrency,
|
||||||
|
Email: apiKey.User.Email,
|
||||||
|
Username: apiKey.User.Username,
|
||||||
|
BalanceNotifyEnabled: apiKey.User.BalanceNotifyEnabled,
|
||||||
|
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
|
||||||
|
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
||||||
|
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
||||||
|
TotalRecharged: apiKey.User.TotalRecharged,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
@ -274,11 +282,18 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
RateLimit1d: snapshot.RateLimit1d,
|
RateLimit1d: snapshot.RateLimit1d,
|
||||||
RateLimit7d: snapshot.RateLimit7d,
|
RateLimit7d: snapshot.RateLimit7d,
|
||||||
User: &User{
|
User: &User{
|
||||||
ID: snapshot.User.ID,
|
ID: snapshot.User.ID,
|
||||||
Status: snapshot.User.Status,
|
Status: snapshot.User.Status,
|
||||||
Role: snapshot.User.Role,
|
Role: snapshot.User.Role,
|
||||||
Balance: snapshot.User.Balance,
|
Balance: snapshot.User.Balance,
|
||||||
Concurrency: snapshot.User.Concurrency,
|
Concurrency: snapshot.User.Concurrency,
|
||||||
|
Email: snapshot.User.Email,
|
||||||
|
Username: snapshot.User.Username,
|
||||||
|
BalanceNotifyEnabled: snapshot.User.BalanceNotifyEnabled,
|
||||||
|
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
|
||||||
|
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
||||||
|
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
||||||
|
TotalRecharged: snapshot.User.TotalRecharged,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
|
|||||||
@ -87,6 +87,18 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) GetNotifyVerifyCode(ctx context.Context, email string) (*VerificationCodeData, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) SetNotifyVerifyCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) DeleteNotifyVerifyCode(ctx context.Context, email string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -107,6 +119,14 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
JWT: config.JWTConfig{
|
JWT: config.JWTConfig{
|
||||||
|
|||||||
404
backend/internal/service/balance_notify_check_test.go
Normal file
404
backend/internal/service/balance_notify_check_test.go
Normal file
@ -0,0 +1,404 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newBalanceNotifyServiceForTest constructs a BalanceNotifyService with an
|
||||||
|
// in-memory settings repo and a non-nil emailService so that the guard-clause
|
||||||
|
// nil-checks pass. The emailService is intentionally minimal — tests must
|
||||||
|
// avoid crossing scenarios that would actually dispatch emails.
|
||||||
|
func newBalanceNotifyServiceForTest() (*BalanceNotifyService, *mockSettingRepo) {
|
||||||
|
repo := newMockSettingRepo()
|
||||||
|
// EmailService is a concrete type; construct with the same repo so that
|
||||||
|
// any accidental fallback reads still succeed. Tests should not trigger a
|
||||||
|
// crossing that reaches SendEmail.
|
||||||
|
email := NewEmailService(repo, nil)
|
||||||
|
return NewBalanceNotifyService(email, repo, nil), repo
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- guard clauses ----------
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_NilUser(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
// Should not panic.
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), nil, 100, 50)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_UserNotifyDisabled(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
|
||||||
|
u := &User{ID: 1, BalanceNotifyEnabled: false}
|
||||||
|
// Even with a crossing, disabled flag short-circuits.
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_GlobalDisabled(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
|
||||||
|
u := &User{ID: 1, BalanceNotifyEnabled: true}
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_ThresholdZero(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "0"
|
||||||
|
u := &User{ID: 1, BalanceNotifyEnabled: true}
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_UserThresholdOverride(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "100" // global default
|
||||||
|
customThreshold := 5.0
|
||||||
|
u := &User{
|
||||||
|
ID: 1,
|
||||||
|
BalanceNotifyEnabled: true,
|
||||||
|
BalanceNotifyThreshold: &customThreshold,
|
||||||
|
}
|
||||||
|
// User's 5.0 threshold takes precedence over global 100. 20 -> 15 does not
|
||||||
|
// cross 5, so nothing fires (verified by absence of panic).
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 20, 15)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckBalanceAfterDeduction_NoCrossingNotFired(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "10"
|
||||||
|
u := &User{ID: 1, BalanceNotifyEnabled: true}
|
||||||
|
|
||||||
|
// 100 -> 95, both remain above threshold=10, no crossing.
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 100, 5)
|
||||||
|
// 5 -> 3, both already below threshold, no crossing (only fires on first
|
||||||
|
// cross from above-to-below).
|
||||||
|
s.CheckBalanceAfterDeduction(context.Background(), u, 5, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- nil-service guards on CheckAccountQuotaAfterIncrement ----------
|
||||||
|
|
||||||
|
func TestCheckAccountQuotaAfterIncrement_NilAccount(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
// Should not panic.
|
||||||
|
s.CheckAccountQuotaAfterIncrement(context.Background(), nil, 10, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAccountQuotaAfterIncrement_ZeroCost(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
|
||||||
|
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAccountQuotaAfterIncrement_NegativeCost(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
a := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
|
||||||
|
s.CheckAccountQuotaAfterIncrement(context.Background(), a, -5, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckAccountQuotaAfterIncrement_GlobalDisabled(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
|
||||||
|
a := &Account{
|
||||||
|
ID: 1,
|
||||||
|
Platform: PlatformAnthropic,
|
||||||
|
Type: AccountTypeAPIKey,
|
||||||
|
Extra: map[string]any{
|
||||||
|
"quota_notify_daily_enabled": true,
|
||||||
|
"quota_notify_daily_threshold": 100.0,
|
||||||
|
"quota_daily_limit": 1000.0,
|
||||||
|
"quota_daily_used": 950.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Global disabled → no processing even if a dim would cross.
|
||||||
|
s.CheckAccountQuotaAfterIncrement(context.Background(), a, 100, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- sanity: internal helpers still work ----------
|
||||||
|
|
||||||
|
func TestGetBalanceNotifyConfig_AllFields(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "12.5"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyRechargeURL] = "https://example.com/pay"
|
||||||
|
|
||||||
|
enabled, threshold, url := s.getBalanceNotifyConfig(context.Background())
|
||||||
|
require.True(t, enabled)
|
||||||
|
require.Equal(t, 12.5, threshold)
|
||||||
|
require.Equal(t, "https://example.com/pay", url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBalanceNotifyConfig_Disabled(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "false"
|
||||||
|
|
||||||
|
enabled, _, _ := s.getBalanceNotifyConfig(context.Background())
|
||||||
|
require.False(t, enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBalanceNotifyConfig_InvalidThreshold(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyEnabled] = "true"
|
||||||
|
repo.data[SettingKeyBalanceLowNotifyThreshold] = "not-a-number"
|
||||||
|
|
||||||
|
enabled, threshold, _ := s.getBalanceNotifyConfig(context.Background())
|
||||||
|
require.True(t, enabled)
|
||||||
|
require.Equal(t, 0.0, threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAccountQuotaNotifyEnabled(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
|
||||||
|
// Missing key → false
|
||||||
|
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
|
||||||
|
|
||||||
|
// Explicit "false"
|
||||||
|
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "false"
|
||||||
|
require.False(t, s.isAccountQuotaNotifyEnabled(context.Background()))
|
||||||
|
|
||||||
|
// Explicit "true"
|
||||||
|
repo.data[SettingKeyAccountQuotaNotifyEnabled] = "true"
|
||||||
|
require.True(t, s.isAccountQuotaNotifyEnabled(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSiteName_FallsBackToDefault(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
name := s.getSiteName(context.Background())
|
||||||
|
require.Equal(t, defaultSiteName, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSiteName_Configured(t *testing.T) {
|
||||||
|
s, repo := newBalanceNotifyServiceForTest()
|
||||||
|
repo.data[SettingKeySiteName] = "My Site"
|
||||||
|
require.Equal(t, "My Site", s.getSiteName(context.Background()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- crossedDownward ----------
|
||||||
|
|
||||||
|
func TestCrossedDownward_CrossesBelow(t *testing.T) {
|
||||||
|
// oldBalance > threshold, newBalance < threshold → true
|
||||||
|
require.True(t, crossedDownward(100, 5, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_ExactlyAtThreshold(t *testing.T) {
|
||||||
|
// oldBalance > threshold, newBalance == threshold → false (not below)
|
||||||
|
require.False(t, crossedDownward(100, 10, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_OldExactlyAtThreshold_NewBelow(t *testing.T) {
|
||||||
|
// oldBalance == threshold, newBalance < threshold → true
|
||||||
|
// (at-or-above → below counts as a crossing)
|
||||||
|
require.True(t, crossedDownward(10, 5, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_AlreadyBelow(t *testing.T) {
|
||||||
|
// oldBalance < threshold → false (already below, no new crossing)
|
||||||
|
require.False(t, crossedDownward(5, 3, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_BothAbove(t *testing.T) {
|
||||||
|
// oldBalance > threshold, newBalance > threshold → false (no crossing)
|
||||||
|
require.False(t, crossedDownward(100, 50, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_ZeroThreshold(t *testing.T) {
|
||||||
|
// threshold == 0 → oldV >= 0 is always true, but newV < 0 only for negatives
|
||||||
|
// Typical case: positive balances should not fire when threshold is 0.
|
||||||
|
require.False(t, crossedDownward(10, 5, 0))
|
||||||
|
require.False(t, crossedDownward(0, 0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_ZeroThreshold_NegativeNew(t *testing.T) {
|
||||||
|
// Edge case: newBalance goes negative with threshold=0.
|
||||||
|
require.True(t, crossedDownward(5, -1, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_NegativeValues(t *testing.T) {
|
||||||
|
// Both already negative, threshold is positive → no crossing (already below).
|
||||||
|
require.False(t, crossedDownward(-5, -10, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_LargeDecrement(t *testing.T) {
|
||||||
|
// A single large deduction crosses the threshold.
|
||||||
|
require.True(t, crossedDownward(1000, 0.5, 100))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCrossedDownward_SmallDecrement_NoCrossing(t *testing.T) {
|
||||||
|
// A tiny deduction stays above threshold.
|
||||||
|
require.False(t, crossedDownward(100, 99.99, 10))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- checkQuotaDimCrossings ----------
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_NoDimensions(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// Empty dims → no crossing, no panic.
|
||||||
|
s.checkQuotaDimCrossings(account, nil, 10, []string{"admin@example.com"}, "TestSite")
|
||||||
|
s.checkQuotaDimCrossings(account, []quotaDim{}, 10, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_DisabledDimension(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: false, // disabled
|
||||||
|
threshold: 100,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 950,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Disabled dimension should be skipped even if crossing would occur.
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_ZeroThresholdSkipped(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 0, // zero threshold
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 950,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Zero threshold → skipped.
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_NoCrossing_BothBelowThreshold(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
|
||||||
|
// currentUsed=300 (after), oldUsed=300-50=250 (before). Both < 600, no crossing.
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 400,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 300,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_NoCrossing_BothAboveThreshold(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// threshold=400 remaining, limit=1000 → effectiveThreshold = 600 (usage trigger)
|
||||||
|
// currentUsed=800 (after), oldUsed=800-50=750 (before). Both >= 600, no crossing.
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 400,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 800,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_NegativeResolvedThreshold_Skipped(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// threshold=1200 remaining, limit=1000 → effectiveThreshold = 1000-1200 = -200
|
||||||
|
// Negative resolved threshold → skipped.
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 1200,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 950,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_PercentageThreshold_NoCrossing(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// threshold=30%, limit=1000 → effectiveThreshold = 1000 * (1 - 0.30) = 700
|
||||||
|
// currentUsed=500, oldUsed=500-50=450. Both < 700, no crossing.
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimWeekly,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 30,
|
||||||
|
thresholdType: thresholdTypePercentage,
|
||||||
|
currentUsed: 500,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_ZeroLimit_Skipped(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// limit=0 → resolvedThreshold returns 0 → skipped.
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimTotal,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 100,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 50,
|
||||||
|
limit: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckQuotaDimCrossings_MultipleDims_MixedResults(t *testing.T) {
|
||||||
|
s, _ := newBalanceNotifyServiceForTest()
|
||||||
|
account := &Account{ID: 1, Name: "test", Platform: PlatformAnthropic}
|
||||||
|
// dim1: no crossing (both below effective threshold)
|
||||||
|
// dim2: disabled (skipped)
|
||||||
|
// dim3: zero threshold (skipped)
|
||||||
|
dims := []quotaDim{
|
||||||
|
{
|
||||||
|
name: quotaDimDaily,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 400,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 300, // oldUsed=250, effectiveThreshold=600, both below
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: quotaDimWeekly,
|
||||||
|
enabled: false,
|
||||||
|
threshold: 100,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 900,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: quotaDimTotal,
|
||||||
|
enabled: true,
|
||||||
|
threshold: 0,
|
||||||
|
thresholdType: thresholdTypeFixed,
|
||||||
|
currentUsed: 500,
|
||||||
|
limit: 1000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// None should trigger. No panic expected.
|
||||||
|
s.checkQuotaDimCrossings(account, dims, 50, []string{"admin@example.com"}, "TestSite")
|
||||||
|
}
|
||||||
147
backend/internal/service/balance_notify_email_body_test.go
Normal file
147
backend/internal/service/balance_notify_email_body_test.go
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These tests guard against fmt.Sprintf arg-count mismatches in the email
|
||||||
|
// templates. A mismatch would produce "%!(EXTRA ...)" or "%!v(MISSING)" in
|
||||||
|
// the output, which these assertions will catch.
|
||||||
|
|
||||||
|
// ---------- buildBalanceLowEmailBody ----------
|
||||||
|
|
||||||
|
func TestBuildBalanceLowEmailBody_ContainsRequiredFields(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildBalanceLowEmailBody("Alice", 3.14, 10.0, "MySite", "")
|
||||||
|
|
||||||
|
// All substituted values should appear in the output.
|
||||||
|
require.Contains(t, body, "MySite")
|
||||||
|
require.Contains(t, body, "Alice")
|
||||||
|
require.Contains(t, body, "$3.14")
|
||||||
|
require.Contains(t, body, "$10.00")
|
||||||
|
|
||||||
|
// No fmt.Sprintf format error markers.
|
||||||
|
require.NotContains(t, body, "%!")
|
||||||
|
require.NotContains(t, body, "MISSING")
|
||||||
|
require.NotContains(t, body, "EXTRA")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildBalanceLowEmailBody_WithRechargeURL(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildBalanceLowEmailBody("Bob", 5.0, 20.0, "Site", "https://example.com/pay")
|
||||||
|
|
||||||
|
// The recharge anchor element should appear with the URL.
|
||||||
|
require.Contains(t, body, `href="https://example.com/pay"`)
|
||||||
|
require.Contains(t, body, "立即充值")
|
||||||
|
require.NotContains(t, body, "%!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildBalanceLowEmailBody_RechargeURLEscaped(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
// Try a URL with characters that need HTML escaping.
|
||||||
|
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", `https://example.com/?a=1&b=<script>`)
|
||||||
|
|
||||||
|
// `&` and `<` should be escaped in the href.
|
||||||
|
require.Contains(t, body, "&")
|
||||||
|
require.Contains(t, body, "<script>")
|
||||||
|
require.NotContains(t, body, "<script>")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildBalanceLowEmailBody_NoRechargeURLOmitsButton(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
|
||||||
|
// The anchor element should not be rendered (style class may still appear).
|
||||||
|
require.NotContains(t, body, `<a href`)
|
||||||
|
require.NotContains(t, body, "立即充值")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- buildQuotaAlertEmailBody ----------
|
||||||
|
|
||||||
|
func TestBuildQuotaAlertEmailBody_AllFieldsPresent(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildQuotaAlertEmailBody(
|
||||||
|
42, // accountID
|
||||||
|
"acc-foo", // accountName
|
||||||
|
"anthropic", // platform
|
||||||
|
"日限额 / Daily", // dimLabel
|
||||||
|
750.50, // used
|
||||||
|
1000.0, // limit
|
||||||
|
249.50, // remaining
|
||||||
|
"$249.50", // thresholdDisplay
|
||||||
|
"MySite", // siteName
|
||||||
|
)
|
||||||
|
|
||||||
|
require.Contains(t, body, "MySite")
|
||||||
|
require.Contains(t, body, "#42")
|
||||||
|
require.Contains(t, body, "acc-foo")
|
||||||
|
require.Contains(t, body, "anthropic")
|
||||||
|
require.Contains(t, body, "Daily")
|
||||||
|
require.Contains(t, body, "$750.50")
|
||||||
|
require.Contains(t, body, "$1000.00")
|
||||||
|
require.Contains(t, body, "$249.50")
|
||||||
|
|
||||||
|
// No format error markers.
|
||||||
|
require.NotContains(t, body, "%!")
|
||||||
|
require.NotContains(t, body, "MISSING")
|
||||||
|
require.NotContains(t, body, "EXTRA")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildQuotaAlertEmailBody_UnlimitedDisplay(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildQuotaAlertEmailBody(
|
||||||
|
1, "n", "p", "dim",
|
||||||
|
100.0, 0.0, // limit=0 triggers unlimited branch
|
||||||
|
0.0, "30%", "Site",
|
||||||
|
)
|
||||||
|
require.Contains(t, body, "无限制")
|
||||||
|
require.Contains(t, body, "Unlimited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildQuotaAlertEmailBody_PercentageThresholdDisplay(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildQuotaAlertEmailBody(
|
||||||
|
1, "n", "p", "dim",
|
||||||
|
700.0, 1000.0, 300.0,
|
||||||
|
"30%", // percentage-formatted threshold
|
||||||
|
"Site",
|
||||||
|
)
|
||||||
|
require.Contains(t, body, "30%")
|
||||||
|
require.NotContains(t, body, "%!")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildQuotaAlertEmailBody_RemainingClampedAtZero(t *testing.T) {
|
||||||
|
// Even though caller is responsible for clamping, this test documents the
|
||||||
|
// display behavior with remaining=0.
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildQuotaAlertEmailBody(
|
||||||
|
1, "n", "p", "dim",
|
||||||
|
1500.0, 1000.0, 0.0, // used > limit (over-quota)
|
||||||
|
"$100.00", "Site",
|
||||||
|
)
|
||||||
|
require.Contains(t, body, "$0.00")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- sanity checks on the CSS `%%` escape ----------
|
||||||
|
|
||||||
|
func TestBuildBalanceLowEmailBody_NoCSSFormatError(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildBalanceLowEmailBody("u", 1.0, 5.0, "Site", "")
|
||||||
|
// CSS `linear-gradient(135deg, #f59e0b 0%, #d97706 100%)` should appear with
|
||||||
|
// literal percent signs (from the %% escape in the template).
|
||||||
|
require.True(t,
|
||||||
|
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
|
||||||
|
"CSS gradient percentages not rendered; got: %s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildQuotaAlertEmailBody_NoCSSFormatError(t *testing.T) {
|
||||||
|
s := &BalanceNotifyService{}
|
||||||
|
body := s.buildQuotaAlertEmailBody(1, "n", "p", "d", 0, 0, 0, "$0.00", "Site")
|
||||||
|
require.True(t,
|
||||||
|
strings.Contains(body, "0%") && strings.Contains(body, "100%"),
|
||||||
|
"CSS gradient percentages not rendered; got: %s", body)
|
||||||
|
}
|
||||||
479
backend/internal/service/balance_notify_service.go
Normal file
479
backend/internal/service/balance_notify_service.go
Normal file
@ -0,0 +1,479 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"log/slog"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
emailSendTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// Threshold type values
|
||||||
|
thresholdTypeFixed = "fixed"
|
||||||
|
thresholdTypePercentage = "percentage"
|
||||||
|
|
||||||
|
// Quota dimension labels
|
||||||
|
quotaDimDaily = "daily"
|
||||||
|
quotaDimWeekly = "weekly"
|
||||||
|
quotaDimTotal = "total"
|
||||||
|
|
||||||
|
defaultSiteName = "Sub2API"
|
||||||
|
)
|
||||||
|
|
||||||
|
// quotaDimLabels maps dimension names to display labels.
|
||||||
|
var quotaDimLabels = map[string]string{
|
||||||
|
quotaDimDaily: "日限额 / Daily",
|
||||||
|
quotaDimWeekly: "周限额 / Weekly",
|
||||||
|
quotaDimTotal: "总限额 / Total",
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountQuotaReader provides read access to account quota data.
|
||||||
|
type AccountQuotaReader interface {
|
||||||
|
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BalanceNotifyService handles balance and quota threshold notifications.
|
||||||
|
type BalanceNotifyService struct {
|
||||||
|
emailService *EmailService
|
||||||
|
settingRepo SettingRepository
|
||||||
|
accountRepo AccountQuotaReader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBalanceNotifyService creates a new BalanceNotifyService.
|
||||||
|
func NewBalanceNotifyService(emailService *EmailService, settingRepo SettingRepository, accountRepo AccountQuotaReader) *BalanceNotifyService {
|
||||||
|
return &BalanceNotifyService{
|
||||||
|
emailService: emailService,
|
||||||
|
settingRepo: settingRepo,
|
||||||
|
accountRepo: accountRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveBalanceThreshold returns the effective balance threshold.
|
||||||
|
// For percentage type, it computes threshold = totalRecharged * percentage / 100.
|
||||||
|
func resolveBalanceThreshold(threshold float64, thresholdType string, totalRecharged float64) float64 {
|
||||||
|
if thresholdType == thresholdTypePercentage && totalRecharged > 0 {
|
||||||
|
return totalRecharged * threshold / 100
|
||||||
|
}
|
||||||
|
return threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckBalanceAfterDeduction checks if balance crossed below threshold after deduction.
|
||||||
|
// Notification is sent only on first crossing: oldBalance >= threshold && newBalance < threshold.
|
||||||
|
func (s *BalanceNotifyService) CheckBalanceAfterDeduction(ctx context.Context, user *User, oldBalance, cost float64) {
|
||||||
|
if !s.canNotifyBalance(user) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
effectiveThreshold, rechargeURL, ok := s.resolveUserEffectiveThreshold(ctx, user)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newBalance := oldBalance - cost
|
||||||
|
if !crossedDownward(oldBalance, newBalance, effectiveThreshold) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.dispatchBalanceLowEmail(ctx, user, newBalance, effectiveThreshold, rechargeURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// canNotifyBalance checks nil guards and user-level toggle.
|
||||||
|
func (s *BalanceNotifyService) canNotifyBalance(user *User) bool {
|
||||||
|
if user == nil || s.emailService == nil || s.settingRepo == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return user.BalanceNotifyEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveUserEffectiveThreshold reads global + user config, returns the effective threshold.
|
||||||
|
// Returns ok=false when notifications should be skipped.
|
||||||
|
func (s *BalanceNotifyService) resolveUserEffectiveThreshold(ctx context.Context, user *User) (effectiveThreshold float64, rechargeURL string, ok bool) {
|
||||||
|
globalEnabled, globalThreshold, rechargeURL := s.getBalanceNotifyConfig(ctx)
|
||||||
|
if !globalEnabled {
|
||||||
|
return 0, "", false
|
||||||
|
}
|
||||||
|
threshold := globalThreshold
|
||||||
|
if user.BalanceNotifyThreshold != nil {
|
||||||
|
threshold = *user.BalanceNotifyThreshold
|
||||||
|
}
|
||||||
|
if threshold <= 0 {
|
||||||
|
return 0, "", false
|
||||||
|
}
|
||||||
|
effectiveThreshold = resolveBalanceThreshold(threshold, user.BalanceNotifyThresholdType, user.TotalRecharged)
|
||||||
|
if effectiveThreshold <= 0 {
|
||||||
|
return 0, "", false
|
||||||
|
}
|
||||||
|
return effectiveThreshold, rechargeURL, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// crossedDownward returns true when oldV was at-or-above threshold but newV dropped below it.
|
||||||
|
func crossedDownward(oldV, newV, threshold float64) bool {
|
||||||
|
return oldV >= threshold && newV < threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchBalanceLowEmail collects recipients and sends the alert in a goroutine.
|
||||||
|
func (s *BalanceNotifyService) dispatchBalanceLowEmail(ctx context.Context, user *User, newBalance, threshold float64, rechargeURL string) {
|
||||||
|
siteName := s.getSiteName(ctx)
|
||||||
|
recipients := s.collectBalanceNotifyRecipients(user)
|
||||||
|
slog.Info("CheckBalanceAfterDeduction: sending notification",
|
||||||
|
"user_id", user.ID, "recipients", recipients, "new_balance", newBalance, "threshold", threshold)
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("panic in balance notification", "recover", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.sendBalanceLowEmails(recipients, user.Username, user.Email, newBalance, threshold, siteName, rechargeURL)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// quotaDim describes one quota dimension for notification checking.
|
||||||
|
type quotaDim struct {
|
||||||
|
name string
|
||||||
|
enabled bool
|
||||||
|
threshold float64
|
||||||
|
thresholdType string // "fixed" (default) or "percentage"
|
||||||
|
currentUsed float64
|
||||||
|
limit float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvedThreshold converts the user-facing "remaining" threshold into a usage-based trigger point.
|
||||||
|
// The threshold represents how much quota REMAINS when the alert fires:
|
||||||
|
// - Fixed ($): threshold=400, limit=1000 → fires when usage reaches 600 (remaining drops to 400)
|
||||||
|
// - Percentage (%): threshold=30, limit=1000 → fires when usage reaches 700 (remaining drops to 30%)
|
||||||
|
func (d quotaDim) resolvedThreshold() float64 {
|
||||||
|
if d.limit <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if d.thresholdType == thresholdTypePercentage {
|
||||||
|
return d.limit * (1 - d.threshold/100)
|
||||||
|
}
|
||||||
|
return d.limit - d.threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildQuotaDims returns the three quota dimensions for notification checking.
|
||||||
|
func buildQuotaDims(account *Account) []quotaDim {
|
||||||
|
return []quotaDim{
|
||||||
|
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), account.GetQuotaDailyUsed(), account.GetQuotaDailyLimit()},
|
||||||
|
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), account.GetQuotaWeeklyUsed(), account.GetQuotaWeeklyLimit()},
|
||||||
|
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), account.GetQuotaUsed(), account.GetQuotaLimit()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildQuotaDimsFromState builds quota dimensions using DB transaction state instead of account snapshot.
|
||||||
|
// Notification settings (enabled, threshold, thresholdType) come from the account; usage values from quotaState.
|
||||||
|
func buildQuotaDimsFromState(account *Account, state *AccountQuotaState) []quotaDim {
|
||||||
|
return []quotaDim{
|
||||||
|
{quotaDimDaily, account.GetQuotaNotifyDailyEnabled(), account.GetQuotaNotifyDailyThreshold(), account.GetQuotaNotifyDailyThresholdType(), state.DailyUsed, state.DailyLimit},
|
||||||
|
{quotaDimWeekly, account.GetQuotaNotifyWeeklyEnabled(), account.GetQuotaNotifyWeeklyThreshold(), account.GetQuotaNotifyWeeklyThresholdType(), state.WeeklyUsed, state.WeeklyLimit},
|
||||||
|
{quotaDimTotal, account.GetQuotaNotifyTotalEnabled(), account.GetQuotaNotifyTotalThreshold(), account.GetQuotaNotifyTotalThresholdType(), state.TotalUsed, state.TotalLimit},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAccountQuotaAfterIncrement checks if any quota dimension crossed above its notify threshold.
|
||||||
|
// When quotaState is non-nil (from DB transaction RETURNING), it is used directly for threshold
|
||||||
|
// checking, avoiding a separate DB read. Otherwise it falls back to fetching fresh account data.
|
||||||
|
func (s *BalanceNotifyService) CheckAccountQuotaAfterIncrement(ctx context.Context, account *Account, cost float64, quotaState *AccountQuotaState) {
|
||||||
|
if account == nil || s.emailService == nil || s.settingRepo == nil || cost <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !s.isAccountQuotaNotifyEnabled(ctx) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
adminEmails := s.getAccountQuotaNotifyEmails(ctx)
|
||||||
|
if len(adminEmails) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
siteName := s.getSiteName(ctx)
|
||||||
|
var dims []quotaDim
|
||||||
|
if quotaState != nil {
|
||||||
|
dims = buildQuotaDimsFromState(account, quotaState)
|
||||||
|
} else {
|
||||||
|
freshAccount := s.fetchFreshAccount(ctx, account)
|
||||||
|
dims = buildQuotaDims(freshAccount)
|
||||||
|
account = freshAccount // use fresh data for alert metadata
|
||||||
|
}
|
||||||
|
s.checkQuotaDimCrossings(account, dims, cost, adminEmails, siteName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchFreshAccount loads the latest account from DB; falls back to the snapshot on error.
|
||||||
|
func (s *BalanceNotifyService) fetchFreshAccount(ctx context.Context, snapshot *Account) *Account {
|
||||||
|
if s.accountRepo == nil {
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
fresh, err := s.accountRepo.GetByID(ctx, snapshot.ID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to fetch fresh account for quota notify, using snapshot",
|
||||||
|
"account_id", snapshot.ID, "error", err)
|
||||||
|
return snapshot
|
||||||
|
}
|
||||||
|
return fresh
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkQuotaDimCrossings iterates pre-built quota dimensions and sends alerts for threshold crossings.
|
||||||
|
// Pre-increment value is reconstructed as currentUsed - cost to detect the crossing moment.
|
||||||
|
func (s *BalanceNotifyService) checkQuotaDimCrossings(account *Account, dims []quotaDim, cost float64, adminEmails []string, siteName string) {
|
||||||
|
for _, dim := range dims {
|
||||||
|
if !dim.enabled || dim.threshold <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
effectiveThreshold := dim.resolvedThreshold()
|
||||||
|
if effectiveThreshold <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newUsed := dim.currentUsed
|
||||||
|
oldUsed := dim.currentUsed - cost
|
||||||
|
if oldUsed < effectiveThreshold && newUsed >= effectiveThreshold {
|
||||||
|
s.asyncSendQuotaAlert(adminEmails, account.ID, account.Name, account.Platform, dim, newUsed, effectiveThreshold, siteName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncSendQuotaAlert sends quota alert email in a goroutine with panic recovery.
|
||||||
|
func (s *BalanceNotifyService) asyncSendQuotaAlert(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, newUsed, effectiveThreshold float64, siteName string) {
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("panic in quota notification", "recover", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
s.sendQuotaAlertEmails(adminEmails, accountID, accountName, platform, dim, newUsed, siteName)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBalanceNotifyConfig reads global balance notification settings.
|
||||||
|
func (s *BalanceNotifyService) getBalanceNotifyConfig(ctx context.Context) (enabled bool, threshold float64, rechargeURL string) {
|
||||||
|
keys := []string{SettingKeyBalanceLowNotifyEnabled, SettingKeyBalanceLowNotifyThreshold, SettingKeyBalanceLowNotifyRechargeURL}
|
||||||
|
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||||
|
if err != nil {
|
||||||
|
return false, 0, ""
|
||||||
|
}
|
||||||
|
enabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
|
||||||
|
if v := settings[SettingKeyBalanceLowNotifyThreshold]; v != "" {
|
||||||
|
if f, err := strconv.ParseFloat(v, 64); err == nil {
|
||||||
|
threshold = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAccountQuotaNotifyEnabled checks the global account quota notification toggle.
|
||||||
|
func (s *BalanceNotifyService) isAccountQuotaNotifyEnabled(ctx context.Context) bool {
|
||||||
|
val, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEnabled)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return val == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAccountQuotaNotifyEmails reads admin notification emails from settings,
|
||||||
|
// filtering out disabled and unverified entries.
|
||||||
|
func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context) []string {
|
||||||
|
raw, err := s.settingRepo.GetValue(ctx, SettingKeyAccountQuotaNotifyEmails)
|
||||||
|
if err != nil || strings.TrimSpace(raw) == "" || raw == "[]" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := ParseNotifyEmails(raw)
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return filterVerifiedEmails(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSiteName reads site name from settings with fallback.
|
||||||
|
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
|
||||||
|
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||||
|
if err != nil || name == "" {
|
||||||
|
return defaultSiteName
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
|
||||||
|
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
|
||||||
|
var recipients []string
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.Disabled || !entry.Verified {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(entry.Email)
|
||||||
|
if email == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lower := strings.ToLower(email)
|
||||||
|
if seen[lower] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[lower] = true
|
||||||
|
recipients = append(recipients, email)
|
||||||
|
}
|
||||||
|
return recipients
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
|
||||||
|
// Only emails with verified=true and disabled=false are included.
|
||||||
|
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
|
||||||
|
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendEmails sends an email to all recipients with shared timeout and error logging.
|
||||||
|
func (s *BalanceNotifyService) sendEmails(recipients []string, subject, body string, logAttrs ...any) {
|
||||||
|
if len(recipients) == 0 {
|
||||||
|
slog.Warn("sendEmails: no recipients", "subject", subject)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, to := range recipients {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), emailSendTimeout)
|
||||||
|
if err := s.emailService.SendEmail(ctx, to, subject, body); err != nil {
|
||||||
|
attrs := append([]any{"to", to, "error", err}, logAttrs...)
|
||||||
|
slog.Error("failed to send notification", attrs...)
|
||||||
|
} else {
|
||||||
|
slog.Info("notification email sent successfully", "to", to, "subject", subject)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendBalanceLowEmails sends balance low notification to all recipients.
|
||||||
|
func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userName, userEmail string, balance, threshold float64, siteName, rechargeURL string) {
|
||||||
|
displayName := userName
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = userEmail
|
||||||
|
}
|
||||||
|
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
|
||||||
|
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName), rechargeURL)
|
||||||
|
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendQuotaAlertEmails sends quota alert notification to admin emails.
|
||||||
|
func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accountID int64, accountName, platform string, dim quotaDim, used float64, siteName string) {
|
||||||
|
dimLabel := quotaDimLabels[dim.name]
|
||||||
|
if dimLabel == "" {
|
||||||
|
dimLabel = dim.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the remaining-based threshold for display
|
||||||
|
thresholdDisplay := fmt.Sprintf("$%.2f", dim.threshold)
|
||||||
|
if dim.thresholdType == thresholdTypePercentage {
|
||||||
|
thresholdDisplay = fmt.Sprintf("%.0f%%", dim.threshold)
|
||||||
|
}
|
||||||
|
remaining := dim.limit - used
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
|
||||||
|
body := s.buildQuotaAlertEmailBody(accountID, html.EscapeString(accountName), html.EscapeString(platform), html.EscapeString(dimLabel), used, dim.limit, remaining, thresholdDisplay, html.EscapeString(siteName))
|
||||||
|
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dim.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
|
||||||
|
func sanitizeEmailHeader(s string) string {
|
||||||
|
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// balanceLowEmailTemplate is the HTML template for balance low notifications.
|
||||||
|
// Format args: siteName, userName, userName, balance, threshold, threshold.
|
||||||
|
// The recharge button is appended dynamically when rechargeURL is set.
|
||||||
|
const balanceLowEmailTemplate = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<style>
|
||||||
|
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||||
|
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||||
|
.header { background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: white; padding: 30px; text-align: center; }
|
||||||
|
.header h1 { margin: 0; font-size: 24px; }
|
||||||
|
.content { padding: 40px 30px; text-align: center; }
|
||||||
|
.balance { font-size: 36px; font-weight: bold; color: #dc2626; margin: 20px 0; }
|
||||||
|
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||||
|
.recharge-btn { display: inline-block; margin-top: 24px; padding: 12px 32px; background: linear-gradient(135deg, #f59e0b 0%%, #d97706 100%%); color: #fff; text-decoration: none; border-radius: 6px; font-size: 16px; font-weight: bold; }
|
||||||
|
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="header"><h1>%s</h1></div>
|
||||||
|
<div class="content">
|
||||||
|
<p style="font-size: 18px; color: #333;">%s,您的余额不足</p>
|
||||||
|
<p style="color: #666;">Dear %s, your balance is running low</p>
|
||||||
|
<div class="balance">$%.2f</div>
|
||||||
|
<div class="info">
|
||||||
|
<p>您的账户余额已低于提醒阈值 <strong>$%.2f</strong>。</p>
|
||||||
|
<p>Your account balance has fallen below the alert threshold of <strong>$%.2f</strong>.</p>
|
||||||
|
<p>请及时充值以免服务中断。</p>
|
||||||
|
<p>Please top up to avoid service interruption.</p>
|
||||||
|
</div>
|
||||||
|
%s
|
||||||
|
</div>
|
||||||
|
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
// quotaAlertEmailTemplate is the HTML template for account quota alert notifications.
|
||||||
|
// Format args: siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay.
|
||||||
|
const quotaAlertEmailTemplate = `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<style>
|
||||||
|
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||||
|
.container { max-width: 600px; margin: 0 auto; background-color: #fff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||||
|
.header { background: linear-gradient(135deg, #ef4444 0%%, #dc2626 100%%); color: white; padding: 30px; text-align: center; }
|
||||||
|
.header h1 { margin: 0; font-size: 24px; }
|
||||||
|
.content { padding: 40px 30px; }
|
||||||
|
.metric { display: flex; justify-content: space-between; padding: 12px 0; border-bottom: 1px solid #eee; }
|
||||||
|
.metric-label { color: #666; }
|
||||||
|
.metric-value { font-weight: bold; color: #333; }
|
||||||
|
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; text-align: center; }
|
||||||
|
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="header"><h1>%s</h1></div>
|
||||||
|
<div class="content">
|
||||||
|
<p style="font-size: 18px; color: #333; text-align: center;">账号限额告警 / Account Quota Alert</p>
|
||||||
|
<div class="metric"><span class="metric-label">账号 ID / Account ID</span><span class="metric-value">#%d</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">账号 / Account</span><span class="metric-value">%s</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">平台 / Platform</span><span class="metric-value">%s</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">维度 / Dimension</span><span class="metric-value">%s</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">已使用 / Used</span><span class="metric-value">$%.2f</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">限额 / Limit</span><span class="metric-value">%s</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">剩余额度 / Remaining</span><span class="metric-value">$%.2f</span></div>
|
||||||
|
<div class="metric"><span class="metric-label">提醒阈值 / Alert Threshold</span><span class="metric-value">%s</span></div>
|
||||||
|
<div class="info">
|
||||||
|
<p>账号剩余额度已低于提醒阈值,请及时关注。</p>
|
||||||
|
<p>Account remaining quota has fallen below the alert threshold.</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="footer"><p>此邮件由系统自动发送,请勿回复。</p></div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
|
||||||
|
// buildBalanceLowEmailBody builds HTML email for balance low notification.
|
||||||
|
func (s *BalanceNotifyService) buildBalanceLowEmailBody(userName string, balance, threshold float64, siteName, rechargeURL string) string {
|
||||||
|
rechargeBlock := ""
|
||||||
|
if rechargeURL != "" {
|
||||||
|
rechargeBlock = fmt.Sprintf(`<a href="%s" class="recharge-btn">立即充值 / Top Up Now</a>`, html.EscapeString(rechargeURL))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(balanceLowEmailTemplate, siteName, userName, userName, balance, threshold, threshold, rechargeBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildQuotaAlertEmailBody builds HTML email for account quota alert.
|
||||||
|
func (s *BalanceNotifyService) buildQuotaAlertEmailBody(accountID int64, accountName, platform, dimLabel string, used, limit, remaining float64, thresholdDisplay, siteName string) string {
|
||||||
|
limitStr := fmt.Sprintf("$%.2f", limit)
|
||||||
|
if limit <= 0 {
|
||||||
|
limitStr = "无限制 / Unlimited"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(quotaAlertEmailTemplate, siteName, accountID, accountName, platform, dimLabel, used, limitStr, remaining, thresholdDisplay)
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user