diff --git a/README.md b/README.md index add1b4eb..5415ea61 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,12 @@ Register now via this link to recei + +RunAPI +Thanks to RunAPI for sponsoring this project! RunAPI is an efficient and stable API platform and OpenRouter alternative. With one API Key, you can access 150+ popular models including OpenAI, Claude, Gemini, DeepSeek, and Grok, with pricing as low as 10% of the original rate. It is highly stable and seamlessly compatible with tools such as Claude Code and OpenClaw. + + + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index 67340969..ca7b3218 100644 --- a/README_CN.md +++ b/README_CN.md @@ -112,6 +112,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 PPToken.org 赞助本项目! PPToken.org 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:1,1 元=1 美元额度;GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术,10 分钟内有回复 。赞助商福利:前 200 名用户通过 [专属注册链接] 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。 + + +RunAPI +感谢 RunAPI 赞助本项目! RunAPI 是高效稳定的API OpenRouter平替平台,一个 API Key 即可访问 OpenAI、Claude、Gemini、DeepSeek、Grok 等 150+ 主流模型,低至 1 折,极其稳定,可以无缝兼容 Claude Code、OpenClaw 等工具。 + + + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index 13d710df..45adfd65 100644 --- a/README_JA.md +++ b/README_JA.md @@ -113,6 +113,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを + +RunAPI +RunAPI のご支援に感謝します!RunAPI は効率的で安定した API プラットフォームで、OpenRouter の代替として利用できます。1つの API キーで OpenAI、Claude、Gemini、DeepSeek、Grok など 150以上の主要モデルにアクセスでき、価格は最低 10% から。非常に安定しており、Claude Code や OpenClaw などのツールとシームレスに互換します。 + + + ## エコシステム diff --git a/assets/partners/logos/runapi.png b/assets/partners/logos/runapi.png new file mode 100644 index 00000000..7f522975 Binary files /dev/null and b/assets/partners/logos/runapi.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index aab9b571..d0c4abc6 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.129 +0.1.130 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 46edcb69..784f309f 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -24,8 +24,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/web" "github.com/gin-gonic/gin" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) //go:embed VERSION @@ -116,11 +114,16 @@ func runSetupServer() { log.Printf("Setup wizard available at http://%s", addr) log.Println("Complete the setup wizard to configure Sub2API") + protocols := new(http.Protocols) + protocols.SetHTTP1(true) + protocols.SetUnencryptedHTTP2(true) + server := &http.Server{ Addr: addr, - Handler: h2c.NewHandler(r, &http2.Server{}), + Handler: r, ReadHeaderTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, + Protocols: protocols, } if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 3e72a49d..76fca0df 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + usageBillingRepository := repository.NewUsageBillingRepository(client, db) + gatewayCache := repository.NewGatewayCache(redisClient) + schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) + schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) - adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) - sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - rpmCache := repository.NewRPMCache(redisClient) - groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) - groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) - claudeOAuthClient := repository.NewClaudeOAuthClient() - oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) - openAIOAuthClient := repository.NewOpenAIOAuthClient() - openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) - geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) - geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() - driveClient := repository.NewGeminiDriveClient() - geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) - antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) + pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) + pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) + if err != nil { + return nil, err + } + billingService := service.NewBillingService(configConfig, pricingService) geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) @@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache) rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) + deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) + openAIOAuthClient := repository.NewOpenAIOAuthClient() + openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) + openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) + channelRepository := repository.NewChannelRepository(db) + channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) + modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) + notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) + balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) + sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) + rpmCache := repository.NewRPMCache(redisClient) + groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache) + groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService) + claudeOAuthClient := repository.NewClaudeOAuthClient() + oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) + geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) + geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() + driveClient := repository.NewGeminiDriveClient() + geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig) + antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) usageCache := service.NewUsageCache() @@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI) - gatewayCache := repository.NewGatewayCache(redisClient) - schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) - schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) @@ -176,25 +191,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService) promoHandler := admin.NewPromoHandler(promoService) opsRepository := repository.NewOpsRepository(db) - usageBillingRepository := repository.NewUsageBillingRepository(client, db) - pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig) - pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient) - if err != nil { - return nil, err - } - billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(identityCache) - deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) digestSessionStore := service.NewDigestSessionStore() - channelRepository := repository.NewChannelRepository(db) - channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService) - modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) - notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService) - balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService) rpmTokenBucketService := service.NewRPMTokenBucketService() gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService) - openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) @@ -271,9 +271,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) - subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService) + subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) diff --git a/backend/go.mod b/backend/go.mod index 74fb4a6e..d8a9c437 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -15,6 +15,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 + github.com/docker/docker v28.5.2+incompatible github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 @@ -41,11 +42,11 @@ require ( github.com/wechatpay-apiv3/wechatpay-go v0.2.21 github.com/zeromicro/go-zero v1.9.4 go.uber.org/zap v1.24.0 - golang.org/x/crypto v0.50.0 + golang.org/x/crypto v0.51.0 golang.org/x/image v0.39.0 - golang.org/x/net v0.53.0 + golang.org/x/net v0.55.0 golang.org/x/sync v0.20.0 - golang.org/x/term v0.42.0 + golang.org/x/term v0.43.0 google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 @@ -87,7 +88,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v28.5.2+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -107,7 +107,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -175,10 +174,10 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.34.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect - golang.org/x/tools v0.43.0 // indirect + golang.org/x/mod v0.35.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/text v0.37.0 // indirect + golang.org/x/tools v0.44.0 // indirect google.golang.org/grpc v1.75.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect diff --git a/backend/go.sum b/backend/go.sum index 3e8f0f04..04a9d449 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -108,8 +108,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM= -github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= @@ -166,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -222,8 +218,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -257,8 +251,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -288,8 +280,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -322,8 +312,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -415,16 +403,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww= golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA= -golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= -golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -436,16 +424,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= -golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s= -golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3ee806d3..ee088d63 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -9,6 +9,7 @@ import ( "net/url" "os" "strings" + "sync/atomic" "time" "github.com/spf13/viper" @@ -581,11 +582,35 @@ type CORSConfig struct { } type SecurityConfig struct { - URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` - ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` - CSP CSPConfig `mapstructure:"csp"` - ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` - ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` + URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` + ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` + CSP CSPConfig `mapstructure:"csp"` + ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` + ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` + TrustForwardedIPForAPIKeyACL bool `mapstructure:"trust_forwarded_ip_for_api_key_acl"` + trustForwardedIPForAPIKeyACLLive *atomic.Bool `mapstructure:"-"` +} + +func (c *Config) TrustForwardedIPForAPIKeyACL() bool { + if c == nil { + return false + } + live := c.Security.trustForwardedIPForAPIKeyACLLive + if live == nil { + return c.Security.TrustForwardedIPForAPIKeyACL + } + return live.Load() +} + +func (c *Config) SetTrustForwardedIPForAPIKeyACL(enabled bool) { + if c == nil { + return + } + c.Security.TrustForwardedIPForAPIKeyACL = enabled + if c.Security.trustForwardedIPForAPIKeyACLLive == nil { + c.Security.trustForwardedIPForAPIKeyACLLive = &atomic.Bool{} + } + c.Security.trustForwardedIPForAPIKeyACLLive.Store(enabled) } type URLAllowlistConfig struct { @@ -1083,7 +1108,8 @@ type GatewaySchedulingConfig struct { FallbackSelectionMode string `mapstructure:"fallback_selection_mode"` // 负载计算 - LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"` // 快照桶读取时的 MGET 分块大小 SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` // 快照重建时的缓存写入分块大小 @@ -1466,6 +1492,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) + cfg.SetTrustForwardedIPForAPIKeyACL(cfg.Security.TrustForwardedIPForAPIKeyACL) cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format)) cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName) @@ -1610,6 +1637,7 @@ func setDefaults() { viper.SetDefault("security.csp.enabled", true) viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + viper.SetDefault("security.trust_forwarded_ip_for_api_key_acl", false) // Security - disable direct fallback on proxy error viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) @@ -1905,6 +1933,7 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200) viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) @@ -2766,6 +2795,9 @@ func (c *Config) Validate() error { if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") } + if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 { + return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative") + } if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 { return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a47de2f8..99fec46c 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { if !cfg.Gateway.Scheduling.LoadBatchEnabled { t.Fatalf("LoadBatchEnabled = false, want true") } + if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 { + t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) + } if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second { t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval) } @@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, wantErr: "gateway.scheduling.sticky_session_max_waiting", }, + { + name: "gateway scheduling load batch cache ttl", + mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 }, + wantErr: "gateway.scheduling.load_batch_cache_ttl_ms", + }, { name: "gateway scheduling outbox poll", mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 }, diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go index 6f0f2aab..defcd29d 100644 --- a/backend/internal/handler/admin/content_moderation_handler.go +++ b/backend/internal/handler/admin/content_moderation_handler.go @@ -20,34 +20,35 @@ func NewContentModerationHandler(svc *service.ContentModerationService) *Content } type contentModerationConfigRequest struct { - Enabled *bool `json:"enabled"` - Mode *string `json:"mode"` - BaseURL *string `json:"base_url"` - Model *string `json:"model"` - APIKey *string `json:"api_key"` - APIKeys *[]string `json:"api_keys"` - APIKeysMode string `json:"api_keys_mode"` - DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` - ClearAPIKey bool `json:"clear_api_key"` - TimeoutMS *int `json:"timeout_ms"` - SampleRate *int `json:"sample_rate"` - AllGroups *bool `json:"all_groups"` - GroupIDs *[]int64 `json:"group_ids"` - RecordNonHits *bool `json:"record_non_hits"` - WorkerCount *int `json:"worker_count"` - QueueSize *int `json:"queue_size"` - BlockStatus *int `json:"block_status"` - BlockMessage *string `json:"block_message"` - EmailOnHit *bool `json:"email_on_hit"` - AutoBanEnabled *bool `json:"auto_ban_enabled"` - BanThreshold *int `json:"ban_threshold"` - ViolationWindowHours *int `json:"violation_window_hours"` - RetryCount *int `json:"retry_count"` - HitRetentionDays *int `json:"hit_retention_days"` - NonHitRetentionDays *int `json:"non_hit_retention_days"` - PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` - BlockedKeywords *[]string `json:"blocked_keywords"` - KeywordBlockingMode *string `json:"keyword_blocking_mode"` + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` + BlockedKeywords *[]string `json:"blocked_keywords"` + KeywordBlockingMode *string `json:"keyword_blocking_mode"` + ModelFilter *service.ContentModerationModelFilter `json:"model_filter"` } type contentModerationAPIKeyTestRequest struct { @@ -107,6 +108,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) { PreHashCheckEnabled: req.PreHashCheckEnabled, BlockedKeywords: req.BlockedKeywords, KeywordBlockingMode: req.KeywordBlockingMode, + ModelFilter: req.ModelFilter, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 7b4300b1..5dca01d5 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -307,6 +307,51 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) { }) } +// BatchUpdate handles batch updating redeem codes +// POST /api/v1/admin/redeem-codes/batch-update +func (h *RedeemHandler) BatchUpdate(c *gin.Context) { + if h.redeemService == nil { + response.InternalError(c, "redeem service not configured") + return + } + + var req dto.BatchUpdateRedeemCodesRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.redeemService.BatchUpdate(c.Request.Context(), &service.RedeemCodeBatchUpdateInput{ + IDs: req.IDs, + Fields: redeemBatchUpdateFieldsFromDTO(req.Fields), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "updated": result.Updated, + "message": "Redeem codes updated successfully", + }) +} + +func redeemBatchUpdateFieldsFromDTO(in dto.BatchUpdateRedeemCodeFields) service.RedeemCodeBatchUpdateFields { + out := service.RedeemCodeBatchUpdateFields{ + Status: in.Status, + Notes: in.Notes, + Type: in.Type, + Value: in.Value, + } + if in.ExpiresAt.Set { + out.ExpiresAt = service.NullableTimeUpdate{Set: true, Value: in.ExpiresAt.Value} + } + if in.GroupID.Set { + out.GroupID = service.NullableInt64Update{Set: true, Value: in.GroupID.Value} + } + return out +} + // Expire handles expiring a redeem code // POST /api/v1/admin/redeem-codes/:id/expire func (h *RedeemHandler) Expire(c *gin.Context) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 14f5dce0..36a60c23 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -142,6 +142,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { TurnstileEnabled: settings.TurnstileEnabled, TurnstileSiteKey: settings.TurnstileSiteKey, TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + APIKeyACLTrustForwardedIP: settings.APIKeyACLTrustForwardedIP, LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, @@ -264,6 +265,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, + SubscriptionExpiryNotifyEnabled: settings.SubscriptionExpiryNotifyEnabled, AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails), PaymentEnabled: paymentCfg.Enabled, @@ -399,6 +401,9 @@ type UpdateSettingsRequest struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKey string `json:"turnstile_secret_key"` + // API Key IP 访问控制设置 + APIKeyACLTrustForwardedIP *bool `json:"api_key_acl_trust_forwarded_ip"` + // LinuxDo Connect OAuth 登录 LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` @@ -582,12 +587,13 @@ type UpdateSettingsRequest struct { // OpenAI account scheduling OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"` - // Balance low notification - BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` + // 余额不足提醒 + BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"` + SubscriptionExpiryNotifyEnabled *bool `json:"subscription_expiry_notify_enabled"` + AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"` // Payment configuration (integrated into settings, full replace) PaymentEnabled *bool `json:"payment_enabled"` @@ -1432,28 +1438,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: req.PromoCodeEnabled, - PasswordResetEnabled: req.PasswordResetEnabled, - FrontendURL: req.FrontendURL, - InvitationCodeEnabled: req.InvitationCodeEnabled, - TotpEnabled: req.TotpEnabled, - LoginAgreementEnabled: req.LoginAgreementEnabled, - LoginAgreementMode: loginAgreementMode, - LoginAgreementUpdatedAt: loginAgreementUpdatedAt, - LoginAgreementDocuments: loginAgreementDocuments, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + FrontendURL: req.FrontendURL, + InvitationCodeEnabled: req.InvitationCodeEnabled, + TotpEnabled: req.TotpEnabled, + LoginAgreementEnabled: req.LoginAgreementEnabled, + LoginAgreementMode: loginAgreementMode, + LoginAgreementUpdatedAt: loginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocuments, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + APIKeyACLTrustForwardedIP: func() bool { + if req.APIKeyACLTrustForwardedIP != nil { + return *req.APIKeyACLTrustForwardedIP + } + return previousSettings.APIKeyACLTrustForwardedIP + }(), LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, LinuxDoConnectClientID: req.LinuxDoConnectClientID, LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, @@ -1669,6 +1681,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.BalanceLowNotifyRechargeURL }(), + SubscriptionExpiryNotifyEnabled: func() bool { + if req.SubscriptionExpiryNotifyEnabled != nil { + return *req.SubscriptionExpiryNotifyEnabled + } + return previousSettings.SubscriptionExpiryNotifyEnabled + }(), AccountQuotaNotifyEnabled: func() bool { if req.AccountQuotaNotifyEnabled != nil { return *req.AccountQuotaNotifyEnabled @@ -1869,6 +1887,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { TurnstileEnabled: updatedSettings.TurnstileEnabled, TurnstileSiteKey: updatedSettings.TurnstileSiteKey, TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + APIKeyACLTrustForwardedIP: updatedSettings.APIKeyACLTrustForwardedIP, LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, @@ -1989,6 +2008,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled, BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL, + SubscriptionExpiryNotifyEnabled: updatedSettings.SubscriptionExpiryNotifyEnabled, AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled, AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails), PaymentEnabled: updatedPaymentCfg.Enabled, @@ -2145,6 +2165,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if req.TurnstileSecretKey != "" { changed = append(changed, "turnstile_secret_key") } + if before.APIKeyACLTrustForwardedIP != after.APIKeyACLTrustForwardedIP { + changed = append(changed, "api_key_acl_trust_forwarded_ip") + } if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled { changed = append(changed, "linuxdo_connect_enabled") } @@ -2454,7 +2477,7 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled { changed = append(changed, "openai_advanced_scheduler_enabled") } - // Balance & quota notification + // 余额、订阅到期与账号限额通知 if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled { changed = append(changed, "balance_low_notify_enabled") } @@ -2464,6 +2487,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL { changed = append(changed, "balance_low_notify_recharge_url") } + if before.SubscriptionExpiryNotifyEnabled != after.SubscriptionExpiryNotifyEnabled { + changed = append(changed, "subscription_expiry_notify_enabled") + } if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled { changed = append(changed, "account_quota_notify_enabled") } @@ -3472,6 +3498,8 @@ func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo) Value: event.Event, Label: event.Label, Description: event.Description, + Category: event.Category, + Optional: event.Optional, }) } return items diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index 584e5751..4081b9e4 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2492,6 +2492,10 @@ func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *servi return err } +func (r *oauthPendingFlowRedeemCodeRepo) BatchUpdate(context.Context, []int64, service.RedeemCodeBatchUpdateFields) (int64, error) { + panic("unexpected BatchUpdate call") +} + func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error { panic("unexpected Delete call") } diff --git a/backend/internal/handler/auth_oauth_test_helpers_test.go b/backend/internal/handler/auth_oauth_test_helpers_test.go index 47bad942..5022ffe8 100644 --- a/backend/internal/handler/auth_oauth_test_helpers_test.go +++ b/backend/internal/handler/auth_oauth_test_helpers_test.go @@ -2,6 +2,7 @@ package handler import ( "net/http" + "net/http/httptest" "net/url" "testing" @@ -32,6 +33,13 @@ func findCookie(cookies []*http.Cookie, name string) *http.Cookie { return nil } +func requireCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) { + t.Helper() + cookie := findCookie(recorder.Result().Cookies(), name) + require.NotNil(t, cookie) + require.Equal(t, -1, cookie.MaxAge) +} + func decodeCookieValueForTest(t *testing.T, value string) string { t.Helper() decoded, err := decodeCookieValue(value) @@ -40,6 +48,13 @@ func decodeCookieValueForTest(t *testing.T, value string) string { } func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) { + t.Helper() + values := parseOAuthRedirectFragment(t, location) + require.Equal(t, errorCode, values.Get("error")) + require.Equal(t, errorMessage, values.Get("error_message")) +} + +func parseOAuthRedirectFragment(t *testing.T, location string) url.Values { t.Helper() require.NotEmpty(t, location) @@ -52,6 +67,5 @@ func assertOAuthRedirectError(t *testing.T, location string, errorCode string, e } values, err := url.ParseQuery(rawValues) require.NoError(t, err) - require.Equal(t, errorCode, values.Get("error")) - require.Equal(t, errorMessage, values.Get("error_message")) + return values } diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index c7c517c8..76eb9498 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -454,6 +454,24 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) { } } + // 快捷路径:当上游返回已验证邮箱、部署不要求额外确认且本地没有同邮箱账号时, + // 直接信任上游身份完成注册/登录,避免展示 choice 页。 + if compatEmailUser == nil && + strings.TrimSpace(compatEmail) != "" && + emailVerified != nil && *emailVerified { + if handled := h.tryOIDCVerifiedEmailFastPath( + c, + frontendCallback, + redirectTo, + identityRef, + compatEmail, + username, + upstreamClaims, + ); handled { + return + } + } + if h.isForceEmailOnThirdPartySignup(c.Request.Context()) { if err := h.createOIDCOAuthChoicePendingSession( c, @@ -1190,3 +1208,70 @@ func oidcClearCookie(c *gin.Context, name string, secure bool) { SameSite: http.SameSiteLaxMode, }) } + +// tryOIDCVerifiedEmailFastPath 在 OIDC 上游已返回已验证邮箱时尝试跳过 choice/pending 页。 +// 返回 true 表示已经写出重定向响应;返回 false 表示调用方应继续回退到常规 choice 流程。 +func (h *AuthHandler) tryOIDCVerifiedEmailFastPath( + c *gin.Context, + frontendCallback string, + redirectTo string, + identity service.PendingAuthIdentityKey, + compatEmail string, + username string, + upstreamClaims map[string]any, +) bool { + if h == nil || h.authService == nil || h.settingSvc == nil { + return false + } + ctx := c.Request.Context() + if h.isForceEmailOnThirdPartySignup(ctx) { + return false + } + if h.settingSvc.IsInvitationCodeEnabled(ctx) { + return false + } + if err := h.ensureBackendModeAllowsNewUserLogin(ctx); err != nil { + log.Printf("[OIDC OAuth] verified-email fast path blocked by backend mode: reason=%s", infraerrors.Reason(err)) + clearOAuthPendingSessionCookie(c, isRequestHTTPS(c)) + clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c)) + redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err)) + return true + } + + verifiedEmail := strings.TrimSpace(strings.ToLower(compatEmail)) + upstreamMetadata := make(map[string]any, len(upstreamClaims)+1) + for k, v := range upstreamClaims { + upstreamMetadata[k] = v + } + if syntheticEmail := pendingSessionStringValue(upstreamClaims, "email"); syntheticEmail != "" && !strings.EqualFold(syntheticEmail, verifiedEmail) { + upstreamMetadata["synthetic_email"] = syntheticEmail + } + upstreamMetadata["email"] = verifiedEmail + input := service.EmailOAuthIdentityInput{ + ProviderType: strings.TrimSpace(identity.ProviderType), + ProviderKey: strings.TrimSpace(identity.ProviderKey), + ProviderSubject: strings.TrimSpace(identity.ProviderSubject), + Email: verifiedEmail, + EmailVerified: true, + Username: strings.TrimSpace(username), + DisplayName: pendingSessionStringValue(upstreamClaims, "suggested_display_name"), + AvatarURL: pendingSessionStringValue(upstreamClaims, "suggested_avatar_url"), + UpstreamMetadata: upstreamMetadata, + } + tokenPair, _, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(ctx, input, "", "") + if err != nil { + log.Printf("[OIDC OAuth] verified-email fast path skipped: reason=%s", infraerrors.Reason(err)) + return false + } + + fragment := url.Values{} + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + clearOAuthPendingSessionCookie(c, isRequestHTTPS(c)) + clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c)) + redirectWithFragment(c, frontendCallback, fragment) + return true +} diff --git a/backend/internal/handler/auth_oidc_oauth_test.go b/backend/internal/handler/auth_oidc_oauth_test.go index 3216d51e..08bb459a 100644 --- a/backend/internal/handler/auth_oidc_oauth_test.go +++ b/backend/internal/handler/auth_oidc_oauth_test.go @@ -10,6 +10,7 @@ import ( "math/big" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -926,6 +927,232 @@ func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUser require.Nil(t, storedSession.ConsumedAt) } +func TestTryOIDCVerifiedEmailFastPathCreatesUserAndIdentity(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, false) + t.Cleanup(func() { _ = client.Close() }) + + ctx := context.Background() + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + identity := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "fast-path-subject", + } + completed := handler.tryOIDCVerifiedEmailFastPath( + c, + "/auth/oidc/callback", + "/dashboard", + identity, + "fastpath@example.com", + "fastpath_user", + map[string]any{ + "suggested_display_name": "Fast Path", + "suggested_avatar_url": "", + }, + ) + require.True(t, completed) + require.Equal(t, http.StatusFound, recorder.Code) + + location := recorder.Header().Get("Location") + require.Contains(t, location, "/auth/oidc/callback") + require.Contains(t, location, "access_token=") + require.Contains(t, location, "refresh_token=") + require.Contains(t, location, "token_type=Bearer") + + user, err := client.User.Query().Where(dbuser.EmailEQ("fastpath@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, "fastpath_user", user.Username) + require.Equal(t, "oidc", user.SignupSource) + + identityRecord, err := client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ("https://issuer.example.com"), + authidentity.ProviderSubjectEQ("fast-path-subject"), + authidentity.UserIDEQ(user.ID), + ).Only(ctx) + require.NoError(t, err) + require.Equal(t, "fastpath@example.com", identityRecord.Metadata["email"]) + require.Equal(t, true, identityRecord.Metadata["email_verified"]) + + pendingCount, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, pendingCount) +} + +func TestOIDCOAuthCallbackVerifiedEmailFastPathIssuesTokenWithoutPendingSession(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-fast-callback-subject", + PreferredUsername: "oidc_fast_callback", + DisplayName: "OIDC Fast Callback", + AvatarURL: "https://cdn.example/oidc-fast.png", + Email: "oidc-fast-callback@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, nil) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-fast-callback", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-fast-callback")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-fast-callback")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-callback-subject")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-fast-callback")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + location := recorder.Header().Get("Location") + require.Contains(t, location, "/auth/oidc/callback#") + require.Contains(t, location, "access_token=") + require.Contains(t, location, "refresh_token=") + require.Contains(t, location, "token_type=Bearer") + fragmentValues := parseOAuthRedirectFragment(t, location) + require.Equal(t, "/dashboard", fragmentValues.Get("redirect")) + requireCookieCleared(t, recorder, oauthPendingSessionCookieName) + requireCookieCleared(t, recorder, oauthPendingBrowserCookieName) + + ctx := context.Background() + user, err := client.User.Query().Where(dbuser.EmailEQ("oidc-fast-callback@example.com")).Only(ctx) + require.NoError(t, err) + require.Equal(t, "oidc_fast_callback", user.Username) + require.Equal(t, "oidc", user.SignupSource) + + identity, err := client.AuthIdentity.Query().Where( + authidentity.ProviderTypeEQ("oidc"), + authidentity.ProviderKeyEQ(cfg.IssuerURL), + authidentity.ProviderSubjectEQ("oidc-fast-callback-subject"), + authidentity.UserIDEQ(user.ID), + ).Only(ctx) + require.NoError(t, err) + require.Equal(t, "oidc-fast-callback@example.com", identity.Metadata["email"]) + require.Equal(t, true, identity.Metadata["email_verified"]) + require.Equal(t, "OIDC Fast Callback", identity.Metadata["suggested_display_name"]) + require.NotEqual(t, identity.Metadata["email"], identity.Metadata["synthetic_email"]) + + pendingCount, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, pendingCount) +} + +func TestOIDCOAuthCallbackVerifiedEmailFastPathBackendModeBlocksBeforeUserCreation(t *testing.T) { + cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{ + Subject: "oidc-fast-backend-mode-subject", + PreferredUsername: "oidc_backend_mode", + DisplayName: "OIDC Backend Mode", + Email: "oidc-backend-mode@example.com", + EmailVerified: true, + }) + defer cleanup() + + handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, map[string]string{ + service.SettingKeyBackendModeEnabled: "true", + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-backend-mode", nil) + req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-backend-mode")) + req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard")) + req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-backend-mode")) + req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-backend-mode-subject")) + req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin)) + req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-backend-mode")) + c.Request = req + + handler.OIDCOAuthCallback(c) + + require.Equal(t, http.StatusFound, recorder.Code) + assertOAuthRedirectError(t, recorder.Header().Get("Location"), "login_blocked", "BACKEND_MODE_ADMIN_ONLY") + requireCookieCleared(t, recorder, oauthPendingSessionCookieName) + requireCookieCleared(t, recorder, oauthPendingBrowserCookieName) + + ctx := context.Background() + userCount, err := client.User.Query().Where(dbuser.EmailEQ("oidc-backend-mode@example.com")).Count(ctx) + require.NoError(t, err) + require.Zero(t, userCount) + identityCount, err := client.AuthIdentity.Query(). + Where(authidentity.ProviderSubjectEQ("oidc-fast-backend-mode-subject")). + Count(ctx) + require.NoError(t, err) + require.Zero(t, identityCount) + pendingCount, err := client.PendingAuthSession.Query().Count(ctx) + require.NoError(t, err) + require.Zero(t, pendingCount) +} + +func TestTryOIDCVerifiedEmailFastPathSkippedWhenInvitationCodeRequired(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandler(t, true) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + identity := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "fast-path-skipped-invitation", + } + completed := handler.tryOIDCVerifiedEmailFastPath( + c, + "/auth/oidc/callback", + "/dashboard", + identity, + "invite-only@example.com", + "invite_only_user", + map[string]any{}, + ) + require.False(t, completed) + require.NotEqual(t, http.StatusFound, recorder.Code) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("invite-only@example.com")).Count(context.Background()) + require.NoError(t, err) + require.Zero(t, userCount) +} + +func TestTryOIDCVerifiedEmailFastPathSkippedWhenForceEmailEnabled(t *testing.T) { + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + settingValues: map[string]string{ + service.SettingKeyForceEmailOnThirdPartySignup: "true", + }, + }) + t.Cleanup(func() { _ = client.Close() }) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil) + + identity := service.PendingAuthIdentityKey{ + ProviderType: "oidc", + ProviderKey: "https://issuer.example.com", + ProviderSubject: "fast-path-skipped-force-email", + } + completed := handler.tryOIDCVerifiedEmailFastPath( + c, + "/auth/oidc/callback", + "/dashboard", + identity, + "force-email@example.com", + "force_email_user", + map[string]any{}, + ) + require.False(t, completed) + + userCount, err := client.User.Query().Where(dbuser.EmailEQ("force-email@example.com")).Count(context.Background()) + require.NoError(t, err) + require.Zero(t, userCount) +} + type oidcProviderFixture struct { Subject string PreferredUsername string @@ -957,6 +1184,49 @@ func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg return handler, client } +func newOIDCOAuthHandlerAndClientWithSettings( + t *testing.T, + invitationEnabled bool, + oauthCfg config.OIDCConnectConfig, + settingValues map[string]string, +) (*AuthHandler, *dbent.Client) { + t.Helper() + + values := map[string]string{ + service.SettingKeyOIDCConnectEnabled: "true", + service.SettingKeyOIDCConnectProviderName: strings.TrimSpace(firstNonEmpty(oauthCfg.ProviderName, "OIDC")), + service.SettingKeyOIDCConnectClientID: oauthCfg.ClientID, + service.SettingKeyOIDCConnectClientSecret: oauthCfg.ClientSecret, + service.SettingKeyOIDCConnectIssuerURL: oauthCfg.IssuerURL, + service.SettingKeyOIDCConnectAuthorizeURL: oauthCfg.AuthorizeURL, + service.SettingKeyOIDCConnectTokenURL: oauthCfg.TokenURL, + service.SettingKeyOIDCConnectUserInfoURL: oauthCfg.UserInfoURL, + service.SettingKeyOIDCConnectJWKSURL: oauthCfg.JWKSURL, + service.SettingKeyOIDCConnectScopes: oauthCfg.Scopes, + service.SettingKeyOIDCConnectRedirectURL: oauthCfg.RedirectURL, + service.SettingKeyOIDCConnectFrontendRedirectURL: oauthCfg.FrontendRedirectURL, + service.SettingKeyOIDCConnectTokenAuthMethod: oauthCfg.TokenAuthMethod, + service.SettingKeyOIDCConnectUsePKCE: boolSettingValue(oauthCfg.UsePKCE), + service.SettingKeyOIDCConnectValidateIDToken: boolSettingValue(oauthCfg.ValidateIDToken), + service.SettingKeyOIDCConnectAllowedSigningAlgs: oauthCfg.AllowedSigningAlgs, + service.SettingKeyOIDCConnectClockSkewSeconds: "120", + service.SettingKeyOIDCConnectRequireEmailVerified: boolSettingValue(oauthCfg.RequireEmailVerified), + } + for key, value := range settingValues { + values[key] = value + } + + handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{ + invitationEnabled: invitationEnabled, + settingValues: values, + }) + if handler.cfg == nil { + handler.cfg = &config.Config{} + } + handler.cfg.OIDC = oauthCfg + return handler, client +} + func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) { t.Helper() diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index bdad5572..fac60573 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -50,6 +50,7 @@ type SystemSettings struct { TurnstileEnabled bool `json:"turnstile_enabled"` TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` + APIKeyACLTrustForwardedIP bool `json:"api_key_acl_trust_forwarded_ip"` LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` @@ -222,12 +223,13 @@ type SystemSettings struct { // Force Alipay mobile clients to use QR code payment instead of mobile redirect PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"` - // Balance low notification - BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` - BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` - BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` - AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` - AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` + // 余额、订阅到期与账号限额通知 + BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"` + BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"` + BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"` + SubscriptionExpiryNotifyEnabled bool `json:"subscription_expiry_notify_enabled"` + AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"` + AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"` // Channel Monitor feature switch ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` @@ -378,11 +380,13 @@ type OpenAIFastPolicySettings struct { Rules []OpenAIFastPolicyRule `json:"rules"` } -// EmailTemplateEventOption describes an editable notification email event. +// EmailTemplateEventOption 描述可编辑的通知邮件事件。 type EmailTemplateEventOption struct { Value string `json:"value"` Label string `json:"label,omitempty"` Description string `json:"description,omitempty"` + Category string `json:"category,omitempty"` + Optional bool `json:"optional,omitempty"` } // EmailTemplateSummary is shown in the admin email template list. diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index cc360f78..31828375 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -1,6 +1,8 @@ package dto import ( + "bytes" + "encoding/json" "time" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -359,6 +361,59 @@ type AdminRedeemCode struct { Notes string `json:"notes"` } +type NullableTimeField struct { + Set bool + Value *time.Time +} + +func (f *NullableTimeField) UnmarshalJSON(data []byte) error { + f.Set = true + if bytes.Equal(data, []byte("null")) { + f.Value = nil + return nil + } + var value time.Time + if err := json.Unmarshal(data, &value); err != nil { + return err + } + f.Value = &value + return nil +} + +type NullableInt64Field struct { + Set bool + Value *int64 +} + +func (f *NullableInt64Field) UnmarshalJSON(data []byte) error { + f.Set = true + if bytes.Equal(data, []byte("null")) { + f.Value = nil + return nil + } + var value int64 + if err := json.Unmarshal(data, &value); err != nil { + return err + } + f.Value = &value + return nil +} + +type BatchUpdateRedeemCodeFields struct { + Status *string `json:"status,omitempty"` + ExpiresAt NullableTimeField `json:"expires_at,omitempty"` + Notes *string `json:"notes,omitempty"` + GroupID NullableInt64Field `json:"group_id,omitempty"` + + Type *string `json:"type,omitempty"` + Value *float64 `json:"value,omitempty"` +} + +type BatchUpdateRedeemCodesRequest struct { + IDs []int64 `json:"ids" binding:"required,min=1"` + Fields BatchUpdateRedeemCodeFields `json:"fields" binding:"required"` +} + // UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。 type UsageLog struct { ID int64 `json:"id"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 0902f6e5..38dad596 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -785,8 +785,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if channelMapping.Mapped { parsedReq.Model = channelMapping.MappedModel parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel) - body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + // Bedrock CC 兼容:渠道模型映射后,清理 Anthropic API 专有字段、注入 Bedrock 必需字段 + parsedReq.Body = h.gatewayService.ApplyBedrockCCCompat(c.Request.Context(), parsedReq.Body, parsedReq.Model, account, apiKey.GroupID) + body = parsedReq.Body // 转发请求 - 根据账号平台分流 c.Set("parsed_request", parsedReq) diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index f7269214..4d523dba 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai_chat_completions.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index e7ba699d..9c5560f5 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } writerSizeBeforeForward := c.Writer.Size() - result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), @@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if channelMappingMsg.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel) } - result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai_messages.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 1a81a59e..e6c37272 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + result, err := func() (*service.OpenAIForwardResult, error) { + defer func() { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + }() + return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel) + }() forwardDurationMs := time.Since(forwardStart).Milliseconds() - if accountReleaseFunc != nil { - accountReleaseFunc() - } upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) responseLatencyMs := forwardDurationMs if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { @@ -217,6 +221,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { zap.Error(err), ) } else { + var imageUpstreamErr *service.OpenAIImagesUpstreamError + if errors.As(err, &imageUpstreamErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + reqLog.Warn("openai.images.upstream_user_error", + zap.Int64("account_id", account.ID), + zap.Int("status_code", imageUpstreamErr.StatusCode), + zap.String("error_type", imageUpstreamErr.ErrorType), + zap.String("error_code", imageUpstreamErr.Code), + zap.Error(err), + ) + return + } var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) @@ -246,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { return } switchCount++ + if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) { + h.handleFailoverExhausted(c, failoverErr, streamStarted) + return + } reqLog.Warn("openai.images.upstream_failover_switching", zap.Int64("account_id", account.ID), zap.Int("upstream_status", failoverErr.StatusCode), diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 166f67d7..fcf88f6c 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -41,13 +41,18 @@ const ( opsErrInsufficientQuota = "insufficient_quota" // 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited) - opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE" - opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED" - opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND" - opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID" - opsCodeUserInactive = "USER_INACTIVE" - opsCodeInvalidAPIKey = "INVALID_API_KEY" - opsCodeAPIKeyRequired = "API_KEY_REQUIRED" + opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE" + opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED" + opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND" + opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID" + opsCodeUserInactive = "USER_INACTIVE" + opsCodeInvalidAPIKey = "INVALID_API_KEY" + opsCodeAPIKeyRequired = "API_KEY_REQUIRED" + opsCodeAPIKeyExpired = "API_KEY_EXPIRED" + opsCodeAPIKeyDisabled = "API_KEY_DISABLED" + opsCodeUserNotFound = "USER_NOT_FOUND" + opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED" + opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated" ) const ( @@ -1091,8 +1096,7 @@ func classifyOpsPhase(errType, message, code string) string { if isOpsClientAuthError(code, msg) { return "auth" } - switch strings.TrimSpace(code) { - case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid: + if isOpsLocalBusinessLimitError(code, msg) { return "request" } @@ -1151,8 +1155,10 @@ func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status i if routingCapacityLimited { phase = "routing" } - localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message)) - isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError) + msg := strings.ToLower(message) + localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, msg) + localBusinessLimited := !upstreamError && classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError) + isBusinessLimited = routingCapacityLimited || (clientBusinessLimited && !upstreamError) || localBusinessLimited errorOwner = classifyOpsErrorOwner(phase, message) errorSource = classifyOpsErrorSource(phase, message) return phase, isBusinessLimited, errorOwner, errorSource @@ -1162,8 +1168,7 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa if len(localClientAuthError) > 0 && localClientAuthError[0] { return true } - switch strings.TrimSpace(code) { - case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive: + if isOpsLocalBusinessLimitError(code, strings.ToLower(message)) { return true } if phase == "billing" || phase == "concurrency" { @@ -1180,10 +1185,45 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa func isOpsClientAuthError(code string, msg string) bool { switch strings.TrimSpace(code) { - case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired: + case opsCodeInvalidAPIKey, + opsCodeAPIKeyRequired, + opsCodeAPIKeyExpired, + opsCodeAPIKeyDisabled, + opsCodeUserNotFound, + opsCodeUserInactive: return true } - return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required") + return strings.Contains(msg, "invalid api key") || + strings.Contains(msg, "api key is required") || + strings.Contains(msg, "api key is disabled") || + strings.Contains(msg, "user associated with api key not found") || + strings.Contains(msg, "user account is not active") +} + +func isOpsLocalBusinessLimitError(code string, msg string) bool { + switch strings.TrimSpace(code) { + case opsCodeInsufficientBalance, + opsCodeUsageLimitExceeded, + opsCodeSubscriptionNotFound, + opsCodeSubscriptionInvalid, + opsCodeAPIKeyQuotaExhausted, + opsCodeAPIKeyQueryDeprecated: + return true + } + return strings.Contains(msg, "api key in query parameter is deprecated") || + strings.Contains(msg, "query parameter api_key is deprecated") || + strings.Contains(msg, "no active subscription found for this group") || + strings.Contains(msg, opsErrInsufficientBalance) || + strings.Contains(msg, "insufficient account balance") || + strings.Contains(msg, "api key group platform is not gemini") || + strings.Contains(msg, "api key 额度已用完") || + strings.Contains(msg, "api key 5小时限额已用完") || + strings.Contains(msg, "api key 日限额已用完") || + strings.Contains(msg, "api key 7天限额已用完") || + strings.Contains(msg, "daily usage limit exceeded") || + strings.Contains(msg, "weekly usage limit exceeded") || + strings.Contains(msg, "monthly usage limit exceeded") || + strings.Contains(msg, "requests-per-minute limit exceeded") } func hasOpsUpstreamErrorContext(c *gin.Context) bool { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 99a9af2f..81f12b0c 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -260,6 +260,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) { code: "API_KEY_REQUIRED", status: http.StatusUnauthorized, }, + { + name: "expired local API key", + errType: "api_error", + message: "API key 已过期", + code: "API_KEY_EXPIRED", + status: http.StatusForbidden, + }, + { + name: "disabled local API key", + errType: "api_error", + message: "API key is disabled", + code: "API_KEY_DISABLED", + status: http.StatusUnauthorized, + }, + { + name: "local API key user missing", + errType: "api_error", + message: "User associated with API key not found", + code: "USER_NOT_FOUND", + status: http.StatusUnauthorized, + }, + { + name: "inactive local API key user", + errType: "api_error", + message: "User account is not active", + code: "USER_INACTIVE", + status: http.StatusUnauthorized, + }, { name: "google invalid API key", errType: "api_error", @@ -274,6 +302,27 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) { code: "401", status: http.StatusUnauthorized, }, + { + name: "google disabled API key", + errType: "api_error", + message: "API key is disabled", + code: "401", + status: http.StatusUnauthorized, + }, + { + name: "google local API key user missing", + errType: "api_error", + message: "User associated with API key not found", + code: "401", + status: http.StatusUnauthorized, + }, + { + name: "google inactive local API key user", + errType: "api_error", + message: "User account is not active", + code: "401", + status: http.StatusUnauthorized, + }, } for _, tt := range tests { @@ -294,6 +343,126 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) { } } +func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) { + tests := []struct { + name string + errType string + message string + code string + status int + wantErrType string + wantPhase string + }{ + { + name: "standard API key quota exhausted", + errType: "api_error", + message: "API key 额度已用完", + code: "API_KEY_QUOTA_EXHAUSTED", + status: http.StatusTooManyRequests, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "standard query API key deprecated", + errType: "api_error", + message: "API key in query parameter is deprecated. Please use Authorization header instead.", + code: "api_key_in_query_deprecated", + status: http.StatusBadRequest, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "google query API key deprecated", + errType: "api_error", + message: "Query parameter api_key is deprecated. Use Authorization header or key instead.", + code: "400", + status: http.StatusBadRequest, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "google no active subscription", + errType: "api_error", + message: "No active subscription found for this group", + code: "403", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "google insufficient account balance", + errType: "api_error", + message: "Insufficient account balance", + code: "403", + status: http.StatusForbidden, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "gateway billing cache insufficient balance", + errType: "billing_error", + message: "insufficient balance", + code: "", + status: http.StatusForbidden, + wantErrType: "billing_error", + wantPhase: "request", + }, + { + name: "gemini group platform mismatch", + errType: "api_error", + message: "API key group platform is not gemini", + code: "400", + status: http.StatusBadRequest, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "gateway API key 5h rate limit", + errType: "api_error", + message: "api key 5小时限额已用完", + code: "rate_limit_exceeded", + status: http.StatusTooManyRequests, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "gateway group RPM limit", + errType: "api_error", + message: "group requests-per-minute limit exceeded", + code: "rate_limit_exceeded", + status: http.StatusTooManyRequests, + wantErrType: "api_error", + wantPhase: "request", + }, + { + name: "google subscription daily limit", + errType: "api_error", + message: "daily usage limit exceeded", + code: "429", + status: http.StatusTooManyRequests, + wantErrType: "api_error", + wantPhase: "request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + errType := normalizeOpsErrorType(tt.errType, tt.code) + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status) + + require.Equal(t, tt.wantErrType, errType) + require.Equal(t, tt.wantPhase, phase) + require.True(t, isBusinessLimited) + require.Equal(t, "client", errorOwner) + require.Equal(t, "client_request", errorSource) + }) + } +} + func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() @@ -372,23 +541,71 @@ func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) { } func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) { - gin.SetMode(gin.TestMode) - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "") + tests := []struct { + name string + message string + code string + status int + }{ + { + name: "invalid API key", + message: "Invalid API key", + code: "401", + status: http.StatusUnauthorized, + }, + { + name: "disabled API key", + message: "API key is disabled", + code: "API_KEY_DISABLED", + status: http.StatusUnauthorized, + }, + { + name: "gemini group platform mismatch", + message: "API key group platform is not gemini", + code: "400", + status: http.StatusBadRequest, + }, + { + name: "provider balance error", + message: "Insufficient account balance", + code: "INSUFFICIENT_BALANCE", + status: http.StatusForbidden, + }, + { + name: "provider subscription error", + message: "No active subscription found for this group", + code: "SUBSCRIPTION_NOT_FOUND", + status: http.StatusForbidden, + }, + { + name: "provider quota error", + message: "api key 额度已用完", + code: "API_KEY_QUOTA_EXHAUSTED", + status: http.StatusTooManyRequests, + }, + } - phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( - c, - "api_error", - "Invalid API key", - "401", - http.StatusUnauthorized, - ) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + service.SetOpsUpstreamError(c, tt.status, tt.message, "") - require.Equal(t, "upstream", phase) - require.False(t, isBusinessLimited) - require.Equal(t, "provider", errorOwner) - require.Equal(t, "upstream_http", errorSource) + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( + c, + "api_error", + tt.message, + tt.code, + tt.status, + ) + + require.Equal(t, "upstream", phase) + require.False(t, isBusinessLimited) + require.Equal(t, "provider", errorOwner) + require.Equal(t, "upstream_http", errorSource) + }) + } } func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) { diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go index 8fb82ef4..09b680c7 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge.go @@ -89,7 +89,7 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) return nil, fmt.Errorf("parse responses input item: %w", err) } - role := rawString(item["role"]) + role := chatCompletionsBridgeRole(rawString(item["role"])) itemType := rawString(item["type"]) switch itemType { case "function_call": @@ -130,9 +130,6 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) continue } - if role == "" { - role = "user" - } content := item["content"] if len(bytesTrimSpace(content)) == 0 { if text := rawString(item["text"]); text != "" { @@ -152,6 +149,17 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage) return messages, nil } +func chatCompletionsBridgeRole(role string) string { + trimmed := strings.TrimSpace(role) + if trimmed == "" { + return "user" + } + if strings.EqualFold(trimmed, "developer") { + return "system" + } + return role +} + func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) { raw = bytesTrimSpace(raw) if len(raw) == 0 || string(raw) == "null" { diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go new file mode 100644 index 00000000..3e55e23a --- /dev/null +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_bridge_test.go @@ -0,0 +1,82 @@ +package apicompat + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResponsesInputToChatMessages_DeveloperRoleMapsToSystem(t *testing.T) { + messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"developer","content":"follow project instructions"}]`)) + require.NoError(t, err) + require.Len(t, messages, 1) + + assert.Equal(t, "system", messages[0].Role) + assert.JSONEq(t, `"follow project instructions"`, string(messages[0].Content)) +} + +func TestResponsesInputToChatMessages_KeepsChatCompletionRoles(t *testing.T) { + input := json.RawMessage(`[ + {"role":"system","content":"system message"}, + {"role":"user","content":"user message"}, + {"role":"assistant","content":"assistant message"}, + {"role":"tool","content":"tool message"} + ]`) + + messages, err := responsesInputToChatMessages("", input) + require.NoError(t, err) + require.Len(t, messages, 4) + + assert.Equal(t, []string{"system", "user", "assistant", "tool"}, chatMessageRoles(messages)) +} + +func TestResponsesInputToChatMessages_EmptyRoleFallsBackToUser(t *testing.T) { + messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"","content":"hello"}]`)) + require.NoError(t, err) + require.Len(t, messages, 1) + + assert.Equal(t, "user", messages[0].Role) +} + +func TestResponsesInputToChatMessages_DeveloperRoleTrimAndCaseInsensitive(t *testing.T) { + input := json.RawMessage(`[ + {"role":" Developer ","content":"one"}, + {"role":"\tDEVELOPER\n","content":"two"} + ]`) + + messages, err := responsesInputToChatMessages("", input) + require.NoError(t, err) + require.Len(t, messages, 2) + + assert.Equal(t, []string{"system", "system"}, chatMessageRoles(messages)) +} + +func TestResponsesToChatCompletionsRequest_InstructionsAndInputDeveloperRole(t *testing.T) { + req := &ResponsesRequest{ + Model: "gpt-4o", + Instructions: "Use concise answers.", + Input: json.RawMessage(`[ + {"role":"developer","content":[{"type":"input_text","text":"Prefer JSON."}]}, + {"role":"user","content":"Hello"} + ]`), + } + + out, err := ResponsesToChatCompletionsRequest(req) + require.NoError(t, err) + require.Len(t, out.Messages, 3) + + assert.Equal(t, []string{"system", "system", "user"}, chatMessageRoles(out.Messages)) + assert.JSONEq(t, `"Use concise answers."`, string(out.Messages[0].Content)) + assert.JSONEq(t, `"Prefer JSON."`, string(out.Messages[1].Content)) + assert.JSONEq(t, `"Hello"`, string(out.Messages[2].Content)) +} + +func chatMessageRoles(messages []ChatMessage) []string { + roles := make([]string, 0, len(messages)) + for _, message := range messages { + roles = append(roles, message.Role) + } + return roles +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 38e4a6a4..e62c4e52 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -325,6 +325,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account if account == nil { return nil } + schedulable := account.Schedulable + if account.Status == service.StatusError { + schedulable = false + } builder := r.client.Account.UpdateOneID(account.ID). SetName(account.Name). @@ -337,7 +341,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account SetPriority(account.Priority). SetStatus(account.Status). SetErrorMessage(account.ErrorMessage). - SetSchedulable(account.Schedulable). + SetSchedulable(schedulable). SetAutoPauseOnExpired(account.AutoPauseOnExpired) if account.RateMultiplier != nil { @@ -458,6 +462,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { return err } } + r.deleteSchedulerAccountSnapshot(ctx, id) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) } @@ -724,6 +729,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str Where(dbaccount.IDEQ(id)). SetStatus(service.StatusError). SetErrorMessage(errorMsg). + SetSchedulable(false). Save(ctx) if err != nil { return err @@ -757,6 +763,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } } +func (r *accountRepository) deleteSchedulerAccountSnapshot(ctx context.Context, accountID int64) { + if r == nil || r.schedulerCache == nil || accountID <= 0 { + return + } + if err := r.schedulerCache.DeleteAccount(ctx, accountID); err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] delete account snapshot failed: id=%d err=%v", accountID, err) + } +} + func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 { return diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index d1cea9eb..9e15047c 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -23,6 +23,7 @@ type AccountRepoSuite struct { type schedulerCacheRecorder struct { setAccounts []*service.Account + deleteIDs []int64 accounts map[int64]*service.Account } @@ -53,6 +54,10 @@ func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *servic } func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error { + s.deleteIDs = append(s.deleteIDs, accountID) + if s.accounts != nil { + delete(s.accounts, accountID) + } return nil } @@ -185,6 +190,27 @@ func (s *AccountRepoSuite) TestDelete() { s.Require().Error(err, "expected error after delete") } +func (s *AccountRepoSuite) TestDelete_RemovesSchedulerAccountSnapshot() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete-cache"}) + cacheRecorder := &schedulerCacheRecorder{ + accounts: map[int64]*service.Account{ + account.ID: { + ID: account.ID, + Name: account.Name, + Status: service.StatusActive, + Schedulable: true, + }, + }, + } + s.repo.schedulerCache = cacheRecorder + + err := s.repo.Delete(s.ctx, account.ID) + s.Require().NoError(err, "Delete") + + s.Require().Equal([]int64{account.ID}, cacheRecorder.deleteIDs) + s.Require().NotContains(cacheRecorder.accounts, account.ID) +} + func (s *AccountRepoSuite) TestDelete_WithGroupBindings() { group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"}) account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"}) @@ -729,7 +755,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() { // --- SetError --- func (s *AccountRepoSuite) TestSetError() { - account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive, Schedulable: true}) s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong")) @@ -737,6 +763,22 @@ func (s *AccountRepoSuite) TestSetError() { s.Require().NoError(err) s.Require().Equal(service.StatusError, got.Status) s.Require().Equal("something went wrong", got.ErrorMessage) + s.Require().False(got.Schedulable) +} + +func (s *AccountRepoSuite) TestUpdateErrorStatusUnschedulesAccount() { + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-update-err", Status: service.StatusActive, Schedulable: true}) + account.Status = service.StatusError + account.ErrorMessage = "token revoked" + account.Schedulable = true + + s.Require().NoError(s.repo.Update(s.ctx, account)) + + got, err := s.repo.GetByID(s.ctx, account.ID) + s.Require().NoError(err) + s.Require().Equal(service.StatusError, got.Status) + s.Require().Equal("token revoked", got.ErrorMessage) + s.Require().False(got.Schedulable) } func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() { diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 47c38d3e..2bdb34b4 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -236,6 +236,91 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC return nil } +func (r *redeemCodeRepository) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) { + uniqueIDs := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + } + if len(uniqueIDs) == 0 { + return 0, nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + return r.batchUpdate(ctx, tx.Client(), uniqueIDs, fields) + } + + tx, err := r.client.Tx(ctx) + if err != nil { + return 0, err + } + txCtx := dbent.NewTxContext(ctx, tx) + defer func() { _ = tx.Rollback() }() + + updated, err := r.batchUpdate(txCtx, tx.Client(), uniqueIDs, fields) + if err != nil { + return 0, err + } + if err := tx.Commit(); err != nil { + return 0, err + } + return updated, nil +} + +func (r *redeemCodeRepository) batchUpdate(ctx context.Context, client *dbent.Client, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) { + existing, err := client.RedeemCode.Query(). + Where(redeemcode.IDIn(ids...)). + All(ctx) + if err != nil { + return 0, err + } + if len(existing) != len(ids) { + return 0, service.ErrRedeemCodeNotFound + } + if fields.TouchesUsedSensitiveFields() { + for _, code := range existing { + if code.Status == service.StatusUsed { + return 0, service.ErrRedeemCodeUsed + } + } + } + + up := client.RedeemCode.Update().Where(redeemcode.IDIn(ids...)) + if fields.Status != nil { + up.SetStatus(*fields.Status) + } + if fields.Notes != nil { + up.SetNotes(*fields.Notes) + } + if fields.ExpiresAt.Set { + if fields.ExpiresAt.Value != nil { + up.SetExpiresAt(*fields.ExpiresAt.Value) + } else { + up.ClearExpiresAt() + } + } + if fields.GroupID.Set { + if fields.GroupID.Value != nil { + up.SetGroupID(*fields.GroupID.Value) + } else { + up.ClearGroupID() + } + } + + affected, err := up.Save(ctx) + if err != nil { + return 0, err + } + if affected != len(ids) { + return 0, service.ErrRedeemCodeNotFound + } + return int64(affected), nil +} + func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error { now := time.Now() client := clientFromContext(ctx, r.client) diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go index 24e5910e..99dc609e 100644 --- a/backend/internal/repository/redeem_code_repo_integration_test.go +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -4,6 +4,7 @@ package repository import ( "context" + "errors" "testing" "time" @@ -21,8 +22,8 @@ type RedeemCodeRepoSuite struct { } func (s *RedeemCodeRepoSuite) SetupTest() { - s.ctx = context.Background() tx := testEntTx(s.T()) + s.ctx = dbent.NewTxContext(context.Background(), tx) s.client = tx.Client() s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository) } @@ -237,6 +238,123 @@ func (s *RedeemCodeRepoSuite) TestUpdate() { s.Require().Equal(float64(50), got.Value) } +func (s *RedeemCodeRepoSuite) TestBatchUpdate_PartialFieldsAndClear() { + group := s.createGroup(uniqueTestValue(s.T(), "batch-update-group")) + groupID := group.ID + expiresAt := time.Now().UTC().Add(2 * time.Hour) + status := service.StatusDisabled + notes := "batch note" + + codeA := &service.RedeemCode{ + Code: "BATCH-UP-A", + Type: service.RedeemTypeBalance, + Value: 10, + Status: service.StatusUnused, + Notes: "old", + } + codeB := &service.RedeemCode{ + Code: "BATCH-UP-B", + Type: service.RedeemTypeBalance, + Value: 20, + Status: service.StatusUnused, + Notes: "old", + } + untouched := &service.RedeemCode{ + Code: "BATCH-UP-C", + Type: service.RedeemTypeBalance, + Value: 30, + Status: service.StatusUnused, + Notes: "keep", + } + s.Require().NoError(s.repo.Create(s.ctx, codeA)) + s.Require().NoError(s.repo.Create(s.ctx, codeB)) + s.Require().NoError(s.repo.Create(s.ctx, untouched)) + + updated, err := s.repo.BatchUpdate(s.ctx, []int64{codeA.ID, codeB.ID}, service.RedeemCodeBatchUpdateFields{ + Status: &status, + ExpiresAt: service.NullableTimeUpdate{Set: true, Value: &expiresAt}, + Notes: ¬es, + GroupID: service.NullableInt64Update{Set: true, Value: &groupID}, + }) + s.Require().NoError(err) + s.Require().Equal(int64(2), updated) + + gotA, err := s.repo.GetByID(s.ctx, codeA.ID) + s.Require().NoError(err) + s.Require().Equal("BATCH-UP-A", gotA.Code) + s.Require().Equal(service.RedeemTypeBalance, gotA.Type) + s.Require().Equal(float64(10), gotA.Value) + s.Require().Equal(service.StatusDisabled, gotA.Status) + s.Require().Equal(notes, gotA.Notes) + s.Require().NotNil(gotA.ExpiresAt) + s.Require().WithinDuration(expiresAt, *gotA.ExpiresAt, time.Second) + s.Require().NotNil(gotA.GroupID) + s.Require().Equal(groupID, *gotA.GroupID) + + gotB, err := s.repo.GetByID(s.ctx, codeB.ID) + s.Require().NoError(err) + s.Require().Equal(service.StatusDisabled, gotB.Status) + s.Require().Equal(notes, gotB.Notes) + + gotUntouched, err := s.repo.GetByID(s.ctx, untouched.ID) + s.Require().NoError(err) + s.Require().Equal(service.StatusUnused, gotUntouched.Status) + s.Require().Equal("keep", gotUntouched.Notes) + s.Require().Nil(gotUntouched.ExpiresAt) + s.Require().Nil(gotUntouched.GroupID) + + updated, err = s.repo.BatchUpdate(s.ctx, []int64{codeA.ID}, service.RedeemCodeBatchUpdateFields{ + ExpiresAt: service.NullableTimeUpdate{Set: true}, + GroupID: service.NullableInt64Update{Set: true}, + }) + s.Require().NoError(err) + s.Require().Equal(int64(1), updated) + + gotA, err = s.repo.GetByID(s.ctx, codeA.ID) + s.Require().NoError(err) + s.Require().Nil(gotA.ExpiresAt) + s.Require().Nil(gotA.GroupID) +} + +func (s *RedeemCodeRepoSuite) TestBatchUpdate_InvalidIDRollsBack() { + code := &service.RedeemCode{ + Code: "BATCH-UP-ROLLBACK", + Type: service.RedeemTypeBalance, + Value: 10, + Status: service.StatusUnused, + Notes: "keep", + } + s.Require().NoError(s.repo.Create(s.ctx, code)) + notes := "changed" + + _, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID, 999999}, service.RedeemCodeBatchUpdateFields{Notes: ¬es}) + s.Require().Error(err) + s.Require().True(errors.Is(err, service.ErrRedeemCodeNotFound)) + + got, getErr := s.repo.GetByID(s.ctx, code.ID) + s.Require().NoError(getErr) + s.Require().Equal("keep", got.Notes) +} + +func (s *RedeemCodeRepoSuite) TestBatchUpdate_UsedCodeRejectsSensitiveFields() { + code := &service.RedeemCode{ + Code: "BATCH-UP-USED", + Type: service.RedeemTypeBalance, + Value: 10, + Status: service.StatusUsed, + } + s.Require().NoError(s.repo.Create(s.ctx, code)) + status := service.StatusDisabled + + _, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID}, service.RedeemCodeBatchUpdateFields{Status: &status}) + s.Require().Error(err) + s.Require().True(errors.Is(err, service.ErrRedeemCodeUsed)) + + got, getErr := s.repo.GetByID(s.ctx, code.ID) + s.Require().NoError(getErr) + s.Require().Equal(service.StatusUsed, got.Status) +} + // --- Use --- func (s *RedeemCodeRepoSuite) TestUse() { diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 65757b62..8dd9e6fb 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -757,6 +757,7 @@ func TestAPIContracts(t *testing.T) { "site_logo": "", "site_subtitle": "Subtitle", "api_base_url": "https://api.example.com", + "api_key_acl_trust_forwarded_ip": false, "contact_info": "support", "doc_url": "https://docs.example.com", "auth_source_default_email_balance": 0, @@ -862,6 +863,7 @@ func TestAPIContracts(t *testing.T) { "payment_alipay_force_qrcode": false, "balance_low_notify_enabled": false, "account_quota_notify_enabled": false, + "subscription_expiry_notify_enabled": true, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", "account_quota_notify_emails": [], @@ -1014,6 +1016,7 @@ func TestAPIContracts(t *testing.T) { "site_logo": "", "site_subtitle": "Subscription to API Conversion Platform", "api_base_url": "", + "api_key_acl_trust_forwarded_ip": false, "contact_info": "", "doc_url": "", "home_content": "", @@ -1086,6 +1089,7 @@ func TestAPIContracts(t *testing.T) { "payment_alipay_force_qrcode": false, "balance_low_notify_enabled": false, "account_quota_notify_enabled": false, + "subscription_expiry_notify_enabled": true, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", "account_quota_notify_emails": [], @@ -1254,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) @@ -1852,6 +1856,10 @@ func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) return errors.New("not implemented") } +func (stubRedeemCodeRepo) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) { + return int64(len(ids)), nil +} + func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index fad4d8e0..e4a18dee 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -18,7 +18,6 @@ import ( "github.com/google/wire" "github.com/redis/go-redis/v9" "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" ) // ProviderSet 提供服务器层的依赖 @@ -102,6 +101,16 @@ func ProvideRouter( // ProvideHTTPServer 提供 HTTP 服务器 func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { httpHandler := http.Handler(router) + server := &http.Server{ + Addr: cfg.Server.Address(), + Handler: httpHandler, + // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 + ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, + // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 + IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second, + // 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟 + // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 + } globalMaxSize := cfg.Server.MaxRequestBodySize if globalMaxSize <= 0 { @@ -115,32 +124,31 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { // 根据配置决定是否启用 H2C if cfg.Server.H2C.Enabled { h2cConfig := cfg.Server.H2C - httpHandler = h2c.NewHandler(router, &http2.Server{ + if err := http2.ConfigureServer(server, &http2.Server{ MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams, IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second, MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize), MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection), MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream), - }) - log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d", - h2cConfig.MaxConcurrentStreams, - h2cConfig.IdleTimeout, - h2cConfig.MaxReadFrameSize, - h2cConfig.MaxUploadBufferPerConnection, - h2cConfig.MaxUploadBufferPerStream, - ) + }); err != nil { + log.Printf("Failed to configure HTTP/2 Cleartext (h2c): %v", err) + } else { + protocols := new(http.Protocols) + protocols.SetHTTP1(true) + protocols.SetUnencryptedHTTP2(true) + server.Protocols = protocols + log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d", + h2cConfig.MaxConcurrentStreams, + h2cConfig.IdleTimeout, + h2cConfig.MaxReadFrameSize, + h2cConfig.MaxUploadBufferPerConnection, + h2cConfig.MaxUploadBufferPerStream, + ) + } } - return &http.Server{ - Addr: cfg.Server.Address(), - Handler: httpHandler, - // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 - ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, - // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 - IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second, - // 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟 - // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 - } + server.Handler = httpHandler + return server } func derefInt64(p *int64) int64 { diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index ee439cda..7b9a1ee0 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -90,6 +90,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { clientIP := ip.GetTrustedClientIP(c) + if cfg.TrustForwardedIPForAPIKeyACL() { + clientIP = ip.GetClientIP(c) + } allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index a00f70c7..57e69f10 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -398,7 +398,7 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) { } } -func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { +func TestAPIKeyAuthIPRestrictionDoesNotTrustForwardedClientIPByDefault(t *testing.T) { gin.SetMode(gin.TestMode) user := &service.User{ @@ -460,6 +460,57 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason) } +func TestAPIKeyAuthIPRestrictionCanTrustForwardedClientIPForReverseProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + cfg.SetTrustForwardedIPForAPIKeyACL(true) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index b14a3a21..6cec1c1a 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -4,6 +4,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "go.uber.org/zap" @@ -31,7 +32,7 @@ func Logger() gin.HandlerFunc { method := c.Request.Method statusCode := c.Writer.Status() - clientIP := c.ClientIP() + clientIP := ip.GetClientIP(c) protocol := c.Request.Proto accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64) platform, _ := c.Request.Context().Value(ctxkey.Platform).(string) diff --git a/backend/internal/server/middleware/request_access_logger_test.go b/backend/internal/server/middleware/request_access_logger_test.go index fec3ed22..bf3666f2 100644 --- a/backend/internal/server/middleware/request_access_logger_test.go +++ b/backend/internal/server/middleware/request_access_logger_test.go @@ -180,6 +180,37 @@ func TestLogger_AccessLogIncludesCoreFields(t *testing.T) { } } +func TestLogger_AccessLogUsesForwardedClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + req.RemoteAddr = "104.23.251.120:443" + req.Header.Set("CF-Connecting-IP", "203.0.113.42") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + + for _, event := range sink.list() { + if event == nil || event.Message != "http request completed" { + continue + } + if got := event.Fields["client_ip"]; got != "203.0.113.42" { + t.Fatalf("client_ip=%q, want real forwarded ip", got) + } + return + } + t.Fatalf("access log event not found") +} + func TestLogger_HealthPathSkipped(t *testing.T) { gin.SetMode(gin.TestMode) sink := initMiddlewareTestLogger(t) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 67164cb9..738ae2c4 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -396,6 +396,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { codes.POST("/generate", h.Admin.Redeem.Generate) codes.DELETE("/:id", h.Admin.Redeem.Delete) codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) + codes.POST("/batch-update", h.Admin.Redeem.BatchUpdate) codes.POST("/:id/expire", h.Admin.Redeem.Expire) } } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 1ac04ede..397004ac 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -511,7 +511,6 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co // testOpenAIAccountConnection tests an OpenAI account's connection func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error { ctx := c.Request.Context() - _ = prompt mode = normalizeAccountTestMode(mode) // Default to openai.DefaultTestModel for OpenAI testing @@ -572,14 +571,8 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account if err != nil { return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error())) } - // 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。 - // 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。 - // TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。 if !openai_compat.ShouldUseResponsesAPI(account.Extra) { - return s.sendErrorAndEnd(c, - "账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+ - "账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。", - ) + return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, normalizedBaseURL, authToken) } apiURL = buildOpenAIResponsesURL(normalizedBaseURL) } else { @@ -654,6 +647,65 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account return s.processOpenAIStream(c, resp.Body) } +// testOpenAIChatCompletionsConnection tests an OpenAI-compatible APIKey account +// through the raw /v1/chat/completions endpoint. +func (s *AccountTestService) testOpenAIChatCompletionsConnection( + c *gin.Context, + account *Account, + testModelID string, + prompt string, + normalizedBaseURL string, + authToken string, +) error { + ctx := c.Request.Context() + apiURL := buildOpenAIChatCompletionsURL(normalizedBaseURL) + + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + payload := createOpenAIChatCompletionsTestPayload(testModelID, prompt) + payloadBytes, _ := json.Marshal(payload) + + s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID}) + s.sendEvent(c, TestEvent{Type: "status", Text: "正在通过 /v1/chat/completions 测试连接"}) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create Chat Completions request") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Authorization", "Bearer "+authToken) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusTooManyRequests { + s.reconcileOpenAI429State(ctx, account, resp.Header, body) + } + if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil { + errMsg := fmt.Sprintf("Chat Completions authentication failed (401): %s", string(body)) + _ = s.accountRepo.SetError(ctx, account.ID, errMsg) + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) returned %d: %s", resp.StatusCode, string(body))) + } + + return s.processOpenAIChatCompletionsStream(c, resp.Body) +} + // testOpenAICompactConnection probes /responses/compact and persists the // resulting capability state on the account. func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error { @@ -1256,6 +1308,24 @@ func createOpenAITestPayload(modelID string, isOAuth bool, prompt string) map[st return payload } +func createOpenAIChatCompletionsTestPayload(modelID string, prompt string) map[string]any { + testPrompt := strings.TrimSpace(prompt) + if testPrompt == "" { + testPrompt = "hi" + } + + return map[string]any{ + "model": modelID, + "messages": []map[string]any{ + { + "role": "user", + "content": testPrompt, + }, + }, + "stream": true, + } +} + // processClaudeStream processes the SSE stream from Claude API func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error { reader := bufio.NewReader(body) @@ -1310,6 +1380,82 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) } } +// processOpenAIChatCompletionsStream processes SSE chunks from the +// OpenAI-compatible Chat Completions API. +func (s *AccountTestService) processOpenAIChatCompletionsStream(c *gin.Context, body io.Reader) error { + reader := bufio.NewReader(body) + seenJSON := false + seenFinish := false + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + if seenFinish { + s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + if seenJSON { + return s.sendErrorAndEnd(c, "Chat Completions stream from /v1/chat/completions ended before [DONE]") + } + return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected SSE JSON data") + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions stream read error from /v1/chat/completions: %s", err.Error())) + } + + line = strings.TrimSpace(line) + if line == "" || !sseDataPrefix.MatchString(line) { + continue + } + + jsonStr := sseDataPrefix.ReplaceAllString(line, "") + if jsonStr == "[DONE]" { + s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + var data map[string]any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected JSON data") + } + seenJSON = true + + if errData, ok := data["error"].(map[string]any); ok { + errorMsg := "Chat Completions API (/v1/chat/completions) returned an error" + if msg, ok := errData["message"].(string); ok && msg != "" { + errorMsg = msg + } + return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) error: %s", errorMsg)) + } + + choices, ok := data["choices"].([]any) + if !ok { + continue + } + for _, choiceValue := range choices { + choice, ok := choiceValue.(map[string]any) + if !ok { + continue + } + if delta, ok := choice["delta"].(map[string]any); ok { + if text, ok := delta["content"].(string); ok && text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + if message, ok := choice["message"].(map[string]any); ok { + if text, ok := message["content"].(string); ok && text != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: text}) + } + } + if finishReason, ok := choice["finish_reason"].(string); ok && finishReason != "" { + seenFinish = true + } + } + } +} + // processOpenAIStream processes the SSE stream from OpenAI Responses API func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error { reader := bufio.NewReader(body) diff --git a/backend/internal/service/account_test_service_openai_test.go b/backend/internal/service/account_test_service_openai_test.go index c6e30ed6..9844957a 100644 --- a/backend/internal/service/account_test_service_openai_test.go +++ b/backend/internal/service/account_test_service_openai_test.go @@ -12,10 +12,12 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" - - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/tidwall/gjson" ) // --- shared test helpers --- @@ -333,3 +335,145 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) { require.Zero(t, repo.clearedErrorID) require.Nil(t, account.RateLimitResetAt) } + +func TestAccountTestService_OpenAIAPIKeyResponsesUnsupportedUsesChatCompletionsPath(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newTestContext() + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"pong"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 91, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://compat-upstream.example/v1", + }, + Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "hello", "") + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists()) + body := recorder.Body.String() + require.Contains(t, body, "pong") + require.Contains(t, body, "已通过 /v1/chat/completions 验证") + require.Contains(t, body, `"success":true`) + require.NotContains(t, body, "当前测试接口仅支持 Responses API 路径") +} + +func TestAccountTestService_OpenAIChatCompletionsPathReturns4xx(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newTestContext() + + upstream := &httpUpstreamRecorder{resp: newJSONResponse(http.StatusBadRequest, `{"error":{"message":"bad request"}}`)} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 92, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://compat-upstream.example", + }, + Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) returned 400") + require.Contains(t, recorder.Body.String(), "/v1/chat/completions") + require.NotContains(t, recorder.Body.String(), `"success":true`) +} + +func TestAccountTestService_OpenAIChatCompletionsPathTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newTestContext() + + upstream := &httpUpstreamRecorder{err: context.DeadlineExceeded} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 93, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://compat-upstream.example", + }, + Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) request failed") + require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + require.Contains(t, recorder.Body.String(), "/v1/chat/completions") + require.NotContains(t, recorder.Body.String(), `"success":true`) +} + +func TestAccountTestService_OpenAIChatCompletionsPathRejectsNonJSONStream(t *testing.T) { + gin.SetMode(gin.TestMode) + ctx, recorder := newTestContext() + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader("data: not-json\n\n")), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 94, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://compat-upstream.example", + }, + Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false}, + } + + err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "") + require.Error(t, err) + require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String()) + require.Contains(t, err.Error(), "Invalid Chat Completions response from /v1/chat/completions") + require.Contains(t, recorder.Body.String(), "/v1/chat/completions") + require.NotContains(t, recorder.Body.String(), `"success":true`) +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 9b5c1afc..cdc5217e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -531,6 +531,7 @@ type adminServiceImpl struct { defaultSubAssigner DefaultSubscriptionAssigner userSubRepo UserSubscriptionRepository privacyClientFactory PrivacyClientFactory + runtimeBlocker AccountRuntimeBlocker } type userGroupRateBatchReader interface { @@ -556,6 +557,7 @@ func NewAdminService( defaultSubAssigner DefaultSubscriptionAssigner, userSubRepo UserSubscriptionRepository, privacyClientFactory PrivacyClientFactory, + runtimeBlocker AccountRuntimeBlocker, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -575,6 +577,7 @@ func NewAdminService( defaultSubAssigner: defaultSubAssigner, userSubRepo: userSubRepo, privacyClientFactory: privacyClientFactory, + runtimeBlocker: runtimeBlocker, } } @@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil { return nil, err } + if s.runtimeBlocker != nil { + s.runtimeBlocker.ClearAccountSchedulingBlock(id) + } return s.accountRepo.GetByID(ctx, id) } diff --git a/backend/internal/service/admin_service_clear_error_test.go b/backend/internal/service/admin_service_clear_error_test.go index 141466dc..41418c4f 100644 --- a/backend/internal/service/admin_service_clear_error_test.go +++ b/backend/internal/service/admin_service_clear_error_test.go @@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes TempUnschedulableReason: "missing refresh token", }, } - svc := &adminServiceImpl{accountRepo: repo} + blocker := &runtimeBlockRecorder{} + svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker} updated, err := svc.ClearAccountError(context.Background(), 31) require.NoError(t, err) @@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes require.Nil(t, updated.RateLimitResetAt) require.Nil(t, updated.TempUnschedulableUntil) require.Empty(t, updated.TempUnschedulableReason) + require.Equal(t, []int64{31}, blocker.clearedIDs) } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c1da92d1..2f764d67 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -325,6 +325,12 @@ func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxy type redeemRepoStub struct { deleteErrByID map[int64]error deletedIDs []int64 + + batchUpdateIDs []int64 + batchUpdateFields RedeemCodeBatchUpdateFields + batchUpdateResult int64 + batchUpdateErr error + batchUpdateCalled bool } func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error { @@ -347,6 +353,19 @@ func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error { panic("unexpected Update call") } +func (s *redeemRepoStub) BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error) { + s.batchUpdateCalled = true + s.batchUpdateIDs = append([]int64(nil), ids...) + s.batchUpdateFields = fields + if s.batchUpdateErr != nil { + return 0, s.batchUpdateErr + } + if s.batchUpdateResult != 0 { + return s.batchUpdateResult, nil + } + return int64(len(ids)), nil +} + func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error { s.deletedIDs = append(s.deletedIDs, id) if s.deleteErrByID != nil { diff --git a/backend/internal/service/auth_email_oauth_auto.go b/backend/internal/service/auth_email_oauth_auto.go index 56fd4004..4db845c2 100644 --- a/backend/internal/service/auth_email_oauth_auto.go +++ b/backend/internal/service/auth_email_oauth_auto.go @@ -49,7 +49,7 @@ func (s *AuthService) loginOrRegisterVerifiedEmailOAuth( } providerType := normalizeOAuthSignupSource(input.ProviderType) - if providerType != "github" && providerType != "google" { + if providerType != "github" && providerType != "google" && providerType != "oidc" { return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid") } providerKey := strings.TrimSpace(input.ProviderKey) diff --git a/backend/internal/service/auth_oauth_email_flow_test.go b/backend/internal/service/auth_oauth_email_flow_test.go index cd76c6b7..3c02587b 100644 --- a/backend/internal/service/auth_oauth_email_flow_test.go +++ b/backend/internal/service/auth_oauth_email_flow_test.go @@ -59,6 +59,10 @@ func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error { return nil } +func (s *redeemCodeRepoStub) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) { + panic("unexpected BatchUpdate call") +} + func (s *redeemCodeRepoStub) Delete(context.Context, int64) error { panic("unexpected Delete call") } diff --git a/backend/internal/service/bedrock_request.go b/backend/internal/service/bedrock_request.go index 2160c13c..8a1fb317 100644 --- a/backend/internal/service/bedrock_request.go +++ b/backend/internal/service/bedrock_request.go @@ -9,12 +9,16 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) const defaultBedrockRegion = "us-east-1" +// featureKeyBedrockCCCompat is the key used in Channel.FeaturesConfig for Bedrock CC compatibility. +const featureKeyBedrockCCCompat = "bedrock_cc_compat" + var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."} // BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀 @@ -179,13 +183,16 @@ func BuildBedrockURL(region, modelID string, stream bool) string { // 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config) // 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true}) // 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl) +// 6. 修复 thinking 字段兼容性(Opus 4.7 仅支持 adaptive,enabled 需要 budget_tokens) +// 7. 清理 tool_use.id / tool_use_id 中 Bedrock 不接受的字符 func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) { betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) - return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens) + return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens, false) } // PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens. -func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) { +// ccCompat 启用 CC 兼容模式时额外处理 thinking 类型转换和 tool_use.id 清理。 +func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string, ccCompat bool) ([]byte, error) { var err error // 注入 anthropic_version(Bedrock 要求) @@ -203,6 +210,7 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens if err != nil { return nil, fmt.Errorf("inject anthropic_beta: %w", err) } + logger.LegacyPrintf("service.gateway", "[Bedrock] Injected beta tokens: %v (model=%s ccCompat=%v)", betaTokens, modelID, ccCompat) } // 移除 model 字段(Bedrock 通过 URL 指定模型) @@ -235,6 +243,12 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens // 清理 cache_control 中 Bedrock 不支持的字段 body = sanitizeBedrockCacheControl(body, modelID) + // CC 兼容模式:修复 CC 发送的 Bedrock 不兼容字段 + if ccCompat { + body = sanitizeBedrockThinking(body, modelID) + body = sanitizeBedrockToolUseIDs(body) + } + return body, nil } @@ -444,17 +458,17 @@ func parseAnthropicBetaHeader(header string) []string { } // bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单 -// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json) +// 参考: AWS Bedrock 官方文档 + litellm anthropic_beta_headers_config.json // 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单 var bedrockSupportedBetaTokens = map[string]bool{ - "computer-use-2025-01-24": true, - "computer-use-2025-11-24": true, - "context-1m-2025-08-07": true, - "context-management-2025-06-27": true, - "compact-2026-01-12": true, - "interleaved-thinking-2025-05-14": true, - "tool-search-tool-2025-10-19": true, - "tool-examples-2025-10-29": true, + "computer-use-2025-01-24": true, + "computer-use-2025-11-24": true, + "context-1m-2025-08-07": true, + // "context-management-2025-06-27": false, // 无官方文档支持 + "compact-2026-01-12": true, // 官方支持,仅 InvokeModel API(Opus 4.6+) + // "interleaved-thinking-2025-05-14": false, // 无官方文档支持 + "tool-search-tool-2025-10-19": true, + "tool-examples-2025-10-29": true, } // bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则 @@ -482,11 +496,8 @@ func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) [ } } - // 检测 thinking / interleaved thinking - // 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta - if gjson.GetBytes(body, "thinking").Exists() { - inject("interleaved-thinking-2025-05-14") - } + // 注意:thinking 字段不再自动注入 interleaved-thinking-2025-05-14 + // 因为该 beta token 未在 AWS Bedrock 官方文档中确认支持 // 检测 computer_use 工具 // tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta @@ -605,3 +616,156 @@ func filterBedrockBetaTokens(tokens []string) []string { return result } + +// bedrockToolUseIDRe 匹配 Bedrock 允许的 tool_use ID 字符(字母、数字、下划线、连字符) +var bedrockToolUseIDRe = regexp.MustCompile(`[^a-zA-Z0-9_-]`) + +// isBedrockOpus47OrNewer 判断 Bedrock 模型 ID 是否为 Claude Opus 4.7 或更新版本 +// Opus 4.7 仅支持 thinking.type: "adaptive",不支持 "enabled" +func isBedrockOpus47OrNewer(modelID string) bool { + lower := strings.ToLower(modelID) + if !strings.Contains(lower, "opus") { + return false + } + matches := claudeVersionRe.FindStringSubmatch(lower) + if matches == nil { + return false + } + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + return major > 4 || (major == 4 && minor >= 7) +} + +const defaultThinkingBudgetTokens = 10000 + +// sanitizeBedrockThinking 修复 thinking 字段的 Bedrock 兼容性问题: +// - Opus 4.7+: 仅支持 "adaptive",将 "enabled" 转换为 "adaptive" 并移除 budget_tokens +// - 其他模型: "enabled" 必须带 budget_tokens,缺失时补充默认值 +func sanitizeBedrockThinking(body []byte, modelID string) []byte { + thinking := gjson.GetBytes(body, "thinking") + if !thinking.Exists() || !thinking.IsObject() { + return body + } + + thinkingType := thinking.Get("type").String() + if thinkingType == "" { + return body + } + + if isBedrockOpus47OrNewer(modelID) { + if thinkingType == "enabled" { + body, _ = sjson.SetBytes(body, "thinking.type", "adaptive") + body, _ = sjson.DeleteBytes(body, "thinking.budget_tokens") + } + return body + } + + if thinkingType == "enabled" && !thinking.Get("budget_tokens").Exists() { + body, _ = sjson.SetBytes(body, "thinking.budget_tokens", defaultThinkingBudgetTokens) + } + + return body +} + +// sanitizeBedrockToolUseIDs 清理 messages 中 tool_use.id 和 tool_result.tool_use_id +// 的非法字符。Bedrock 要求 ID 匹配 '^[a-zA-Z0-9_-]+$'。 +func sanitizeBedrockToolUseIDs(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + for mi, msg := range messages.Array() { + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + continue + } + for ci, block := range content.Array() { + switch block.Get("type").String() { + case "tool_use": + body = sanitizeIDField(body, block.Get("id").String(), fmt.Sprintf("messages.%d.content.%d.id", mi, ci)) + case "tool_result": + body = sanitizeIDField(body, block.Get("tool_use_id").String(), fmt.Sprintf("messages.%d.content.%d.tool_use_id", mi, ci)) + } + } + } + return body +} + +func sanitizeIDField(body []byte, id, path string) []byte { + if id == "" { + return body + } + sanitized := bedrockToolUseIDRe.ReplaceAllString(id, "_") + if sanitized != id { + body, _ = sjson.SetBytes(body, path, sanitized) + } + return body +} + +const defaultCCMaxTokens = 81920 + +// sanitizeBedrockCCFields 处理 Claude Code 发送的 Bedrock 不兼容字段: +// - 移除 service_tier(Anthropic API 专有,Bedrock 不支持) +// - 移除 interface_geo(Anthropic API 专有,Bedrock 不支持) +// - 移除 context_management(Anthropic API 专有,Bedrock 不支持,CC v2.1.87+ 默认携带) +// - 注入 max_tokens 默认值 81920(CC 可能省略,Bedrock 要求必须提供) +// - 注入 anthropic_version(CC 通过 HTTP 头发送,Bedrock 需要放在请求体中) +func sanitizeBedrockCCFields(body []byte) []byte { + if gjson.GetBytes(body, "service_tier").Exists() { + body, _ = sjson.DeleteBytes(body, "service_tier") + } + if gjson.GetBytes(body, "interface_geo").Exists() { + body, _ = sjson.DeleteBytes(body, "interface_geo") + } + if gjson.GetBytes(body, "context_management").Exists() { + body, _ = sjson.DeleteBytes(body, "context_management") + } + if !gjson.GetBytes(body, "max_tokens").Exists() { + body, _ = sjson.SetBytes(body, "max_tokens", defaultCCMaxTokens) + } + if !gjson.GetBytes(body, "anthropic_version").Exists() { + body, _ = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31") + } + return body +} + +// sanitizeBedrockCCBetaTokens 清理请求体中的 anthropic_beta 字段,只保留 Bedrock 支持的 beta token +// CC 可能在请求体中注入了 Bedrock 不支持的 beta token(如 prompt-caching 等),导致 ValidationException +func sanitizeBedrockCCBetaTokens(body []byte, modelID string) []byte { + betaField := gjson.GetBytes(body, "anthropic_beta") + if !betaField.Exists() { + return body + } + + var tokens []string + if betaField.IsArray() { + for _, t := range betaField.Array() { + if t.Type == gjson.String { + tokens = append(tokens, t.String()) + } + } + } + + originalTokens := append([]string(nil), tokens...) // 保存原始 tokens 用于日志 + + // 复用现有的 Bedrock beta token 过滤逻辑(自动注入 + 白名单过滤 + 转换) + // 即使 tokens 为空,也要执行自动注入(根据 body 内容补充必要的 beta token) + tokens = autoInjectBedrockBetaTokens(tokens, body, modelID) + tokens = filterBedrockBetaTokens(tokens) + + if len(tokens) == 0 { + // 所有 token 都被过滤掉,删除 anthropic_beta 字段 + body, _ = sjson.DeleteBytes(body, "anthropic_beta") + logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Removed all beta tokens: original=%v", originalTokens) + } else { + // 更新为过滤后的 token 列表 + body, _ = sjson.SetBytes(body, "anthropic_beta", tokens) + if len(originalTokens) > 0 { + logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Filtered beta tokens: original=%v final=%v", originalTokens, tokens) + } else { + logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Auto-injected beta tokens: %v", tokens) + } + } + + return body +} diff --git a/backend/internal/service/bedrock_request_test.go b/backend/internal/service/bedrock_request_test.go index 361cafb4..98942ba4 100644 --- a/backend/internal/service/bedrock_request_test.go +++ b/backend/internal/service/bedrock_request_test.go @@ -216,7 +216,7 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) { ] }` - betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12" + betaHeader := "context-1m-2025-08-07, compact-2026-01-12" result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader) require.NoError(t, err) @@ -228,10 +228,9 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) { // anthropic_beta 应包含所有 beta tokens betaArr := gjson.GetBytes(result, "anthropic_beta").Array() - require.Len(t, betaArr, 3) - assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String()) - assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String()) - assert.Equal(t, "compact-2026-01-12", betaArr[2].String()) + require.Len(t, betaArr, 2) + assert.Equal(t, "context-1m-2025-08-07", betaArr[0].String()) + assert.Equal(t, "compact-2026-01-12", betaArr[1].String()) // output_format 应被移除,schema 内联到最后一条 user message assert.False(t, gjson.GetBytes(result, "output_format").Exists()) @@ -264,29 +263,29 @@ func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) { }) t.Run("single beta token", func(t *testing.T) { - result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14") + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07") require.NoError(t, err) arr := gjson.GetBytes(result, "anthropic_beta").Array() require.Len(t, arr, 1) - assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[0].String()) }) t.Run("multiple beta tokens with spaces", func(t *testing.T) { - result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ") + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07 , compact-2026-01-12 ") require.NoError(t, err) arr := gjson.GetBytes(result, "anthropic_beta").Array() require.Len(t, arr, 2) - assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) - assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[0].String()) + assert.Equal(t, "compact-2026-01-12", arr[1].String()) }) t.Run("json array beta header", func(t *testing.T) { - result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`) + result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["context-1m-2025-08-07","compact-2026-01-12"]`) require.NoError(t, err) arr := gjson.GetBytes(result, "anthropic_beta").Array() require.Len(t, arr, 2) - assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) - assert.Equal(t, "context-1m-2025-08-07", arr[1].String()) + assert.Equal(t, "context-1m-2025-08-07", arr[0].String()) + assert.Equal(t, "compact-2026-01-12", arr[1].String()) }) } @@ -301,15 +300,15 @@ func TestParseAnthropicBetaHeader(t *testing.T) { func TestFilterBedrockBetaTokens(t *testing.T) { t.Run("supported tokens pass through", func(t *testing.T) { - tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"} + tokens := []string{"context-1m-2025-08-07", "compact-2026-01-12", "computer-use-2025-11-24"} result := filterBedrockBetaTokens(tokens) assert.Equal(t, tokens, result) }) t.Run("unsupported tokens are filtered out", func(t *testing.T) { - tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"} + tokens := []string{"context-1m-2025-08-07", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"} result := filterBedrockBetaTokens(tokens) - assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + assert.Equal(t, []string{"context-1m-2025-08-07"}, result) }) t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) { @@ -361,11 +360,11 @@ func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) { t.Run("unsupported beta tokens are filtered", func(t *testing.T) { result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", - "interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14") + "compact-2026-01-12, output-128k-2025-02-19, files-api-2025-04-14") require.NoError(t, err) arr := gjson.GetBytes(result, "anthropic_beta").Array() require.Len(t, arr, 1) - assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String()) + assert.Equal(t, "compact-2026-01-12", arr[0].String()) }) t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) { @@ -498,22 +497,17 @@ func TestResolveBedrockModelID(t *testing.T) { } func TestAutoInjectBedrockBetaTokens(t *testing.T) { - t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) { + t.Run("no auto-inject for thinking (interleaved-thinking not supported)", func(t *testing.T) { body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1") - assert.Contains(t, result, "interleaved-thinking-2025-05-14") + // interleaved-thinking-2025-05-14 已从白名单移除,不应自动注入 + assert.Empty(t, result) }) t.Run("no duplicate when already present", func(t *testing.T) { body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`) - result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1") - count := 0 - for _, t := range result { - if t == "interleaved-thinking-2025-05-14" { - count++ - } - } - assert.Equal(t, 1, count) + result := autoInjectBedrockBetaTokens([]string{"context-1m-2025-08-07"}, body, "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, []string{"context-1m-2025-08-07"}, result) }) t.Run("inject computer-use when computer tool present", func(t *testing.T) { @@ -574,7 +568,8 @@ func TestAutoInjectBedrockBetaTokens(t *testing.T) { result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1") assert.Contains(t, result, "context-1m-2025-08-07") assert.Contains(t, result, "compact-2026-01-12") - assert.Contains(t, result, "interleaved-thinking-2025-05-14") + // interleaved-thinking 不再自动注入 + assert.NotContains(t, result, "interleaved-thinking-2025-05-14") }) } @@ -588,27 +583,21 @@ func TestResolveBedrockBetaTokens(t *testing.T) { t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) { body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`) - result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1") - assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result) + result := ResolveBedrockBetaTokens("context-1m-2025-08-07,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, []string{"context-1m-2025-08-07"}, result) }) } func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) { - t.Run("thinking in body auto-injects beta without header", func(t *testing.T) { + t.Run("thinking in body does not auto-inject beta (not supported)", func(t *testing.T) { input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "") require.NoError(t, err) - arr := gjson.GetBytes(result, "anthropic_beta").Array() - found := false - for _, v := range arr { - if v.String() == "interleaved-thinking-2025-05-14" { - found = true - } - } - assert.True(t, found, "interleaved-thinking should be auto-injected") + // interleaved-thinking 已从白名单移除,不应自动注入 + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) }) - t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) { + t.Run("header tokens preserved without auto-injection", func(t *testing.T) { input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}` result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07") require.NoError(t, err) @@ -618,7 +607,8 @@ func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) { names[i] = v.String() } assert.Contains(t, names, "context-1m-2025-08-07") - assert.Contains(t, names, "interleaved-thinking-2025-05-14") + // interleaved-thinking 不再自动注入 + assert.NotContains(t, names, "interleaved-thinking-2025-05-14") }) } @@ -657,3 +647,363 @@ func TestAdjustBedrockModelRegionPrefix(t *testing.T) { }) } } + +func TestIsBedrockOpus47OrNewer(t *testing.T) { + tests := []struct { + modelID string + expect bool + }{ + {"us.anthropic.claude-opus-4-7-v1", true}, + {"us.anthropic.claude-opus-4-6-v1", false}, + {"us.anthropic.claude-opus-4-5-20251101-v1:0", false}, + {"us.anthropic.claude-opus-5-0-v1", true}, + // Sonnet 4.7 is not Opus → false + {"us.anthropic.claude-sonnet-4-7-v1", false}, + {"us.anthropic.claude-sonnet-4-6", false}, + // Haiku is not Opus + {"us.anthropic.claude-haiku-4-5-20251001-v1:0", false}, + // Non-Claude models + {"amazon.nova-pro-v1", false}, + } + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID)) + }) + } +} + +func TestSanitizeBedrockThinking(t *testing.T) { + t.Run("opus 4.7 converts enabled to adaptive", func(t *testing.T) { + input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists()) + }) + + t.Run("opus 4.7 keeps adaptive unchanged", func(t *testing.T) { + input := `{"thinking":{"type":"adaptive"},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + }) + + t.Run("opus 4.6 enabled without budget_tokens gets default", func(t *testing.T) { + input := `{"thinking":{"type":"enabled"},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + }) + + t.Run("opus 4.6 enabled with budget_tokens unchanged", func(t *testing.T) { + input := `{"thinking":{"type":"enabled","budget_tokens":20000},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1") + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(20000), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + }) + + t.Run("no thinking field unchanged", func(t *testing.T) { + input := `{"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.JSONEq(t, input, string(result)) + }) + + t.Run("sonnet 4.6 enabled without budget_tokens gets default", func(t *testing.T) { + input := `{"thinking":{"type":"enabled"},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-sonnet-4-6") + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + }) +} + +func TestSanitizeBedrockToolUseIDs(t *testing.T) { + t.Run("clean IDs unchanged", func(t *testing.T) { + input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu_01AbCdEf","name":"bash","input":{}}]}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.Equal(t, "toolu_01AbCdEf", gjson.GetBytes(result, "messages.0.content.0.id").String()) + }) + + t.Run("dots in tool_use ID replaced with underscores", func(t *testing.T) { + input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String()) + }) + + t.Run("special chars in tool_use ID sanitized", func(t *testing.T) { + input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu:01@Ab#Cd","name":"bash","input":{}}]}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + id := gjson.GetBytes(result, "messages.0.content.0.id").String() + assert.Regexp(t, `^[a-zA-Z0-9_-]+$`, id) + }) + + t.Run("tool_result tool_use_id sanitized", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.tool_use_id").String()) + }) + + t.Run("mixed clean and dirty IDs", func(t *testing.T) { + input := `{"messages":[ + {"role":"assistant","content":[{"type":"tool_use","id":"clean_id-123","name":"a","input":{}}]}, + {"role":"user","content":[{"type":"tool_result","tool_use_id":"dirty.id@456","content":"ok"}]}, + {"role":"assistant","content":[{"type":"tool_use","id":"also.dirty","name":"b","input":{}}]} + ]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.Equal(t, "clean_id-123", gjson.GetBytes(result, "messages.0.content.0.id").String()) + assert.Equal(t, "dirty_id_456", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String()) + assert.Equal(t, "also_dirty", gjson.GetBytes(result, "messages.2.content.0.id").String()) + }) + + t.Run("no messages unchanged", func(t *testing.T) { + input := `{"system":[{"type":"text","text":"hi"}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.JSONEq(t, input, string(result)) + }) + + t.Run("string content skipped", func(t *testing.T) { + input := `{"messages":[{"role":"user","content":"plain text"}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.JSONEq(t, input, string(result)) + }) + + t.Run("empty ID skipped", func(t *testing.T) { + input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"","name":"a","input":{}}]}]}` + result := sanitizeBedrockToolUseIDs([]byte(input)) + assert.Equal(t, "", gjson.GetBytes(result, "messages.0.content.0.id").String()) + }) +} + +func TestSanitizeBedrockThinking_EdgeCases(t *testing.T) { + t.Run("opus 4.7 enabled without budget_tokens converts to adaptive", func(t *testing.T) { + input := `{"thinking":{"type":"enabled"},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists()) + }) + + t.Run("thinking type disabled unchanged", func(t *testing.T) { + input := `{"thinking":{"type":"disabled"},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.Equal(t, "disabled", gjson.GetBytes(result, "thinking.type").String()) + }) + + t.Run("thinking type empty string unchanged", func(t *testing.T) { + input := `{"thinking":{"type":""},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.JSONEq(t, input, string(result)) + }) + + t.Run("thinking is not an object unchanged", func(t *testing.T) { + input := `{"thinking":true,"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.JSONEq(t, input, string(result)) + }) + + t.Run("opus 4.7 adaptive with budget_tokens preserved", func(t *testing.T) { + input := `{"thinking":{"type":"adaptive","budget_tokens":5000},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1") + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(5000), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + }) + + // Forward() passes parsed.Model (standard names like "claude-opus-4-7") + t.Run("standard model name opus 4.7 converts enabled to adaptive", func(t *testing.T) { + input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-7") + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists()) + }) + + t.Run("standard model name opus 4.6 keeps enabled", func(t *testing.T) { + input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}` + result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-6") + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(10000), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + }) +} + +func TestIsBedrockOpus47OrNewer_EdgeCases(t *testing.T) { + tests := []struct { + modelID string + expect bool + }{ + {"anthropic.claude-opus-4-7-v1", true}, + {"us.anthropic.claude-opus-4-7-20270101-v1:0", true}, + {"", false}, + // Forward() passes parsed.Model (standard names), not Bedrock IDs + {"claude-opus-4-7", true}, + {"claude-opus-4-6", false}, + {"claude-sonnet-4-7", false}, + } + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID)) + }) + } +} + +func TestPrepareBedrockRequestBodyWithTokens_CCCompat(t *testing.T) { + input := `{ + "model":"claude-opus-4-6", + "stream":true, + "max_tokens":16384, + "thinking":{"type":"enabled"}, + "messages":[ + {"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]}, + {"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]} + ] + }` + + t.Run("ccCompat=false skips thinking and toolUseID sanitization", func(t *testing.T) { + result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, false) + require.NoError(t, err) + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists()) + assert.Equal(t, "toolu.01.Ab", gjson.GetBytes(result, "messages.0.content.0.id").String()) + }) + + t.Run("ccCompat=true applies thinking fix and toolUseID sanitization (opus 4.6)", func(t *testing.T) { + result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, true) + require.NoError(t, err) + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int()) + assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String()) + assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String()) + }) + + t.Run("ccCompat=true converts thinking to adaptive for opus 4.7", func(t *testing.T) { + result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-7-v1", nil, true) + require.NoError(t, err) + assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String()) + assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists()) + assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String()) + }) +} + +func TestSanitizeBedrockCCFields(t *testing.T) { + t.Run("removes service_tier and interface_geo", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","service_tier":"standard","interface_geo":"us","messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.False(t, gjson.GetBytes(result, "service_tier").Exists()) + assert.False(t, gjson.GetBytes(result, "interface_geo").Exists()) + assert.True(t, gjson.GetBytes(result, "messages").Exists()) + }) + + t.Run("removes context_management", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + assert.True(t, gjson.GetBytes(result, "messages").Exists()) + }) + + t.Run("injects max_tokens when missing", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int()) + }) + + t.Run("preserves existing max_tokens", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","max_tokens":4096,"messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.Equal(t, int64(4096), gjson.GetBytes(result, "max_tokens").Int()) + }) + + t.Run("injects anthropic_version when missing", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + }) + + t.Run("preserves existing anthropic_version", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","anthropic_version":"bedrock-2023-05-31","messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + }) + + t.Run("no-op when fields already clean", func(t *testing.T) { + body := []byte(`{"model":"claude-opus-4-6","max_tokens":81920,"anthropic_version":"bedrock-2023-05-31","messages":[]}`) + result := sanitizeBedrockCCFields(body) + assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int()) + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + assert.False(t, gjson.GetBytes(result, "service_tier").Exists()) + assert.False(t, gjson.GetBytes(result, "interface_geo").Exists()) + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + }) + + t.Run("full CC request sanitization", func(t *testing.T) { + body := []byte(`{ + "model":"claude-opus-4-6", + "service_tier":"standard", + "interface_geo":"us", + "context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}, + "messages":[{"role":"user","content":"hello"}], + "thinking":{"type":"enabled"} + }`) + result := sanitizeBedrockCCFields(body) + assert.False(t, gjson.GetBytes(result, "service_tier").Exists()) + assert.False(t, gjson.GetBytes(result, "interface_geo").Exists()) + assert.False(t, gjson.GetBytes(result, "context_management").Exists()) + assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int()) + assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String()) + assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String()) + }) +} + +func TestSanitizeBedrockCCBetaTokens(t *testing.T) { + t.Run("filters unsupported beta tokens", func(t *testing.T) { + input := `{"anthropic_beta":["prompt-caching-2024-07-31","context-1m-2025-08-07","unsupported-feature"],"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + beta := gjson.GetBytes(result, "anthropic_beta") + assert.True(t, beta.Exists()) + assert.True(t, beta.IsArray()) + tokens := beta.Array() + assert.Equal(t, 1, len(tokens)) + assert.Equal(t, "context-1m-2025-08-07", tokens[0].String()) + }) + + t.Run("removes anthropic_beta if all tokens filtered", func(t *testing.T) { + input := `{"anthropic_beta":["prompt-caching-2024-07-31","unsupported-feature"],"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + }) + + t.Run("thinking alone does not auto-inject beta tokens", func(t *testing.T) { + input := `{"anthropic_beta":[],"thinking":{"type":"enabled"},"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + }) + + t.Run("auto-injects computer-use beta token", func(t *testing.T) { + input := `{"anthropic_beta":[],"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + beta := gjson.GetBytes(result, "anthropic_beta") + assert.True(t, beta.Exists()) + tokens := beta.Array() + assert.Equal(t, 1, len(tokens)) + assert.Equal(t, "computer-use-2025-11-24", tokens[0].String()) + }) + + t.Run("transforms advanced-tool-use to tool-search-tool", func(t *testing.T) { + input := `{"anthropic_beta":["advanced-tool-use-2025-11-20"],"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + beta := gjson.GetBytes(result, "anthropic_beta") + tokens := beta.Array() + assert.Equal(t, 2, len(tokens)) // tool-search-tool + tool-examples (auto-associated) + assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-search-tool-2025-10-19") + assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-examples-2025-10-29") + }) + + t.Run("no-op when anthropic_beta not present", func(t *testing.T) { + input := `{"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists()) + }) + + t.Run("preserves supported beta tokens", func(t *testing.T) { + input := `{"anthropic_beta":["computer-use-2025-11-24","context-1m-2025-08-07"],"messages":[]}` + result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6") + beta := gjson.GetBytes(result, "anthropic_beta") + tokens := beta.Array() + assert.Equal(t, 2, len(tokens)) + assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "computer-use-2025-11-24") + assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "context-1m-2025-08-07") + }) +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 760f688d..88ed2df7 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -248,6 +248,17 @@ func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool { return ok && enabled } +// IsBedrockCCCompatEnabled 返回该渠道是否启用了 Bedrock CC 兼容模式。 +// 一旦启用,该渠道下所有请求都会应用 CC 兼容转换,不区分账号 platform。 +func (c *Channel) IsBedrockCCCompatEnabled(platform string) bool { + if c == nil || c.FeaturesConfig == nil { + return false + } + // 直接检查 bedrock_cc_compat 开关,不再检查 platform 子字段 + enabled, ok := c.FeaturesConfig[featureKeyBedrockCCCompat].(bool) + return ok && enabled +} + // deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. func deepCopyFeaturesConfig(src map[string]any) map[string]any { dst := make(map[string]any, len(src)) diff --git a/backend/internal/service/channel_bedrock_cc_test.go b/backend/internal/service/channel_bedrock_cc_test.go new file mode 100644 index 00000000..1d0476f6 --- /dev/null +++ b/backend/internal/service/channel_bedrock_cc_test.go @@ -0,0 +1,73 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestChannel_IsBedrockCCCompatEnabled_Enabled(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyBedrockCCCompat: true, + }, + } + require.True(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_AppliesToAllPlatforms(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyBedrockCCCompat: true, + }, + } + require.True(t, c.IsBedrockCCCompatEnabled("anthropic")) + require.True(t, c.IsBedrockCCCompatEnabled("openai")) + require.True(t, c.IsBedrockCCCompatEnabled("")) +} + +func TestChannel_IsBedrockCCCompatEnabled_Disabled(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyBedrockCCCompat: false, + }, + } + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_NilFeaturesConfig(t *testing.T) { + c := &Channel{FeaturesConfig: nil} + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_NilChannel(t *testing.T) { + var c *Channel + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_WrongType(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyBedrockCCCompat: "yes", + }, + } + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_OldMapFormat(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + featureKeyBedrockCCCompat: map[string]any{"bedrock": true}, + }, + } + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} + +func TestChannel_IsBedrockCCCompatEnabled_MissingKey(t *testing.T) { + c := &Channel{ + FeaturesConfig: map[string]any{ + "other_feature": true, + }, + } + require.False(t, c.IsBedrockCCCompatEnabled("bedrock")) +} diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go index 25737e45..7fb829a3 100644 --- a/backend/internal/service/channel_monitor_checker.go +++ b/backend/internal/service/channel_monitor_checker.go @@ -281,9 +281,54 @@ func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt if err != nil { return "", "", status, err } + if provider == MonitorProviderOpenAI && apiMode == MonitorAPIModeResponses { + return extractOpenAIResponsesText(respBytes), string(respBytes), status, nil + } return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil } +// extractOpenAIResponsesText 聚合 Responses API 的最终 assistant 文本。 +// Responses 的 output 数组顺序由模型决定:reasoning / tool-call item 可能排在 message 前面, +// 因此不能假设文本永远在 output.0.content.0.text。 +func extractOpenAIResponsesText(respBytes []byte) string { + if text := gjson.GetBytes(respBytes, "output_text").String(); strings.TrimSpace(text) != "" { + return text + } + + var texts []string + outputs := gjson.GetBytes(respBytes, "output") + if outputs.IsArray() { + outputs.ForEach(func(_, output gjson.Result) bool { + outputType := output.Get("type").String() + if outputType != "" && outputType != "message" { + return true + } + + content := output.Get("content") + if !content.IsArray() { + return true + } + + content.ForEach(func(_, block gjson.Result) bool { + blockType := block.Get("type").String() + if blockType != "" && blockType != "output_text" { + return true + } + if text := block.Get("text").String(); strings.TrimSpace(text) != "" { + texts = append(texts, text) + } + return true + }) + return true + }) + } + + if len(texts) > 0 { + return strings.Join(texts, "") + } + return gjson.GetBytes(respBytes, providerOpenAIResponsesAdapter.textPath).String() +} + // mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。 // 用户值覆盖默认;命中黑名单(hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。 func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string { diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go index 620cf565..bba3d7df 100644 --- a/backend/internal/service/channel_monitor_checker_body_test.go +++ b/backend/internal/service/channel_monitor_checker_body_test.go @@ -60,10 +60,11 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string { } type openAICaptureHandler struct { - lastBody map[string]any - lastHeaders http.Header - lastPath string - status int + lastBody map[string]any + lastHeaders http.Header + lastPath string + status int + responsesLeadingReasoning bool } func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -82,10 +83,23 @@ func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) answer := answerFromOpenAIRequest(parsed) if h.lastPath == providerOpenAIResponsesPath { + output := []map[string]any{} + if h.responsesLeadingReasoning { + output = append(output, map[string]any{ + "type": "reasoning", + "summary": []any{}, + }) + } + output = append(output, map[string]any{ + "type": "message", + "status": "completed", + "role": "assistant", + "content": []map[string]any{ + {"type": "output_text", "text": answer}, + }, + }) _ = json.NewEncoder(w).Encode(map[string]any{ - "output": []map[string]any{{ - "content": []map[string]any{{"type": "output_text", "text": answer}}, - }}, + "output": output, }) return } @@ -212,6 +226,22 @@ func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) { } } +func TestRunCheckForModel_OpenAIResponses_SkipsLeadingReasoningItem(t *testing.T) { + h := &openAICaptureHandler{responsesLeadingReasoning: true} + endpoint := setupFakeOpenAI(t, h) + + res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-5.5", &CheckOptions{ + APIMode: MonitorAPIModeResponses, + }) + + if res.Status != MonitorStatusOperational { + t.Fatalf("responses request should find text after leading reasoning item, got status=%s message=%q", res.Status, res.Message) + } + if h.lastPath != providerOpenAIResponsesPath { + t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath) + } +} + func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) { h := &openAICaptureHandler{} endpoint := setupFakeOpenAI(t, h) diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 386d5ed0..712fc1a7 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -3,13 +3,17 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/binary" + "encoding/hex" "os" "strconv" + "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // ConcurrencyCache 定义并发控制的缓存接口 @@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error } const ( - // Default extra wait slots beyond concurrency limit + // 默认等待队列额外槽位 defaultExtraWaitSlots = 20 + + defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond + accountLoadBatchFetchTimeout = 3 * time.Second + maxAccountLoadBatchCacheEntries = 256 ) -// ConcurrencyService manages concurrent request limiting for accounts and users +// ConcurrencyService 管理账号和用户的并发限制。 type ConcurrencyService struct { cache ConcurrencyCache + + accountLoadCacheTTL atomic.Int64 + accountLoadCacheMu sync.RWMutex + accountLoadCache map[string]cachedAccountLoadBatch + accountLoadGroup singleflight.Group } -// NewConcurrencyService creates a new ConcurrencyService +type cachedAccountLoadBatch struct { + loadMap map[int64]*AccountLoadInfo + expiresAt time.Time +} + +// NewConcurrencyService 创建并发控制服务。 func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService { - return &ConcurrencyService{cache: cache} + svc := &ConcurrencyService{ + cache: cache, + accountLoadCache: make(map[string]cachedAccountLoadBatch), + } + svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL) + return svc +} + +// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。 +func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) { + if s == nil { + return + } + s.accountLoadCacheTTL.Store(int64(ttl)) + if ttl <= 0 { + s.accountLoadCacheMu.Lock() + s.accountLoadCache = make(map[string]cachedAccountLoadBatch) + s.accountLoadCacheMu.Unlock() + } } // AcquireResult represents the result of acquiring a concurrency slot @@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int { return userConcurrency + defaultExtraWaitSlots } -// GetAccountsLoadBatch returns load info for multiple accounts. +// GetAccountsLoadBatch 批量获取账号负载信息。 func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return s.getAccountsLoadBatch(ctx, accounts, true) +} + +// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。 +func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return s.getAccountsLoadBatch(ctx, accounts, false) +} + +func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) { + if len(accounts) == 0 { + return map[int64]*AccountLoadInfo{}, nil + } if s.cache == nil { return map[int64]*AccountLoadInfo{}, nil } - return s.cache.GetAccountsLoadBatch(ctx, accounts) + + ttl := time.Duration(s.accountLoadCacheTTL.Load()) + if !allowCache || ttl <= 0 { + return s.fetchAccountsLoadBatch(ctx, accounts) + } + + key := accountLoadBatchCacheKey(accounts) + if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok { + return cached, nil + } + + value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) { + now := time.Now() + if cached, ok := s.getCachedAccountLoadBatch(key, now); ok { + return cached, nil + } + loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts) + if fetchErr != nil { + return nil, fetchErr + } + cached := cloneAccountLoadMap(loadMap) + s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl)) + return cached, nil + }) + if err != nil { + return nil, err + } + loadMap, _ := value.(map[int64]*AccountLoadInfo) + if loadMap == nil { + return map[int64]*AccountLoadInfo{}, nil + } + return loadMap, nil +} + +func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + if s.cache == nil { + return map[int64]*AccountLoadInfo{}, nil + } + baseCtx := context.Background() + if ctx != nil { + baseCtx = context.WithoutCancel(ctx) + } + redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout) + defer cancel() + return s.cache.GetAccountsLoadBatch(redisCtx, accounts) +} + +func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) { + s.accountLoadCacheMu.RLock() + cached, ok := s.accountLoadCache[key] + s.accountLoadCacheMu.RUnlock() + if !ok { + return nil, false + } + if !now.Before(cached.expiresAt) { + s.accountLoadCacheMu.Lock() + if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) { + delete(s.accountLoadCache, key) + } + s.accountLoadCacheMu.Unlock() + return nil, false + } + return cached.loadMap, true +} + +func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) { + s.accountLoadCacheMu.Lock() + if s.accountLoadCache == nil { + s.accountLoadCache = make(map[string]cachedAccountLoadBatch) + } + if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries { + now := time.Now() + for cacheKey, cached := range s.accountLoadCache { + if !now.Before(cached.expiresAt) { + delete(s.accountLoadCache, cacheKey) + } + } + for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries { + for cacheKey := range s.accountLoadCache { + delete(s.accountLoadCache, cacheKey) + break + } + } + } + s.accountLoadCache[key] = cachedAccountLoadBatch{ + loadMap: loadMap, + expiresAt: expiresAt, + } + s.accountLoadCacheMu.Unlock() +} + +func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string { + hash := sha256.New() + var buf [16]byte + for _, account := range accounts { + binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID)) + binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency))) + _, _ = hash.Write(buf[:]) + } + sum := hash.Sum(nil) + return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum) +} + +func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo { + if len(loadMap) == 0 { + return map[int64]*AccountLoadInfo{} + } + clone := make(map[int64]*AccountLoadInfo, len(loadMap)) + for accountID, loadInfo := range loadMap { + if loadInfo == nil { + clone[accountID] = nil + continue + } + copied := *loadInfo + clone[accountID] = &copied + } + return clone } // GetUsersLoadBatch returns load info for multiple users. diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 078ba0dc..7d5f501d 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -7,7 +7,9 @@ import ( "errors" "strconv" "strings" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct { // 记录调用 releasedAccountIDs []int64 releasedRequestIDs []string + loadBatchCalls atomic.Int64 } var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) @@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in return nil } func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + c.loadBatchCalls.Add(1) return c.loadBatch, c.loadBatchErr } func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { @@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) { require.Empty(t, result) } +func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) { + cache := &stubConcurrencyCacheForTest{ + loadBatch: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20}, + }, + } + svc := NewConcurrencyService(cache) + svc.SetAccountLoadBatchCacheTTL(time.Second) + + accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}} + first, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 1, first[int64(1)].CurrentConcurrency) + + cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80} + second, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 1, second[int64(1)].CurrentConcurrency) + require.Equal(t, int64(1), cache.loadBatchCalls.Load()) +} + +func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) { + cache := &stubConcurrencyCacheForTest{ + loadBatch: map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20}, + }, + } + svc := NewConcurrencyService(cache) + svc.SetAccountLoadBatchCacheTTL(time.Second) + + accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}} + _, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + + cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80} + fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency) + require.Equal(t, int64(2), cache.loadBatchCalls.Load()) +} + func TestIncrementWaitCount_Success(t *testing.T) { cache := &stubConcurrencyCacheForTest{waitAllowed: true} svc := NewConcurrencyService(cache) diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go index 2d066298..b5a889e1 100644 --- a/backend/internal/service/content_moderation.go +++ b/backend/internal/service/content_moderation.go @@ -44,6 +44,10 @@ const ( ContentModerationKeywordModeKeywordAndAPI = "keyword_and_api" ContentModerationKeywordModeAPIOnly = "api_only" + ContentModerationModelFilterAll = "all" + ContentModerationModelFilterInclude = "include" + ContentModerationModelFilterExclude = "exclude" + ContentModerationProtocolAnthropicMessages = "anthropic_messages" ContentModerationProtocolOpenAIResponses = "openai_responses" ContentModerationProtocolOpenAIChat = "openai_chat_completions" @@ -80,6 +84,8 @@ const ( maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024 maxContentModerationBlockedKeywords = 10000 maxContentModerationBlockedKeywordRunes = 200 + maxContentModerationModelFilterModels = 1000 + maxContentModerationModelFilterRunes = 200 contentModerationCleanupInterval = 24 * time.Hour contentModerationCleanupTimeout = 30 * time.Minute @@ -127,32 +133,33 @@ func ContentModerationCategories() []string { } type ContentModerationConfig struct { - Enabled bool `json:"enabled"` - Mode string `json:"mode"` - BaseURL string `json:"base_url"` - Model string `json:"model"` - APIKey string `json:"api_key,omitempty"` - APIKeys []string `json:"api_keys,omitempty"` - TimeoutMS int `json:"timeout_ms"` - SampleRate int `json:"sample_rate"` - AllGroups bool `json:"all_groups"` - GroupIDs []int64 `json:"group_ids"` - RecordNonHits bool `json:"record_non_hits"` - Thresholds map[string]float64 `json:"thresholds"` - WorkerCount int `json:"worker_count"` - QueueSize int `json:"queue_size"` - BlockStatus int `json:"block_status"` - BlockMessage string `json:"block_message"` - EmailOnHit bool `json:"email_on_hit"` - AutoBanEnabled bool `json:"auto_ban_enabled"` - BanThreshold int `json:"ban_threshold"` - ViolationWindowHours int `json:"violation_window_hours"` - RetryCount int `json:"retry_count"` - HitRetentionDays int `json:"hit_retention_days"` - NonHitRetentionDays int `json:"non_hit_retention_days"` - PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` - BlockedKeywords []string `json:"blocked_keywords"` - KeywordBlockingMode string `json:"keyword_blocking_mode"` + Enabled bool `json:"enabled"` + Mode string `json:"mode"` + BaseURL string `json:"base_url"` + Model string `json:"model"` + APIKey string `json:"api_key,omitempty"` + APIKeys []string `json:"api_keys,omitempty"` + TimeoutMS int `json:"timeout_ms"` + SampleRate int `json:"sample_rate"` + AllGroups bool `json:"all_groups"` + GroupIDs []int64 `json:"group_ids"` + RecordNonHits bool `json:"record_non_hits"` + Thresholds map[string]float64 `json:"thresholds"` + WorkerCount int `json:"worker_count"` + QueueSize int `json:"queue_size"` + BlockStatus int `json:"block_status"` + BlockMessage string `json:"block_message"` + EmailOnHit bool `json:"email_on_hit"` + AutoBanEnabled bool `json:"auto_ban_enabled"` + BanThreshold int `json:"ban_threshold"` + ViolationWindowHours int `json:"violation_window_hours"` + RetryCount int `json:"retry_count"` + HitRetentionDays int `json:"hit_retention_days"` + NonHitRetentionDays int `json:"non_hit_retention_days"` + PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` + BlockedKeywords []string `json:"blocked_keywords"` + KeywordBlockingMode string `json:"keyword_blocking_mode"` + ModelFilter ContentModerationModelFilter `json:"model_filter"` } type ContentModerationConfigView struct { @@ -184,6 +191,7 @@ type ContentModerationConfigView struct { PreHashCheckEnabled bool `json:"pre_hash_check_enabled"` BlockedKeywords []string `json:"blocked_keywords"` KeywordBlockingMode string `json:"keyword_blocking_mode"` + ModelFilter ContentModerationModelFilter `json:"model_filter"` } type ContentModerationAPIKeyStatus struct { @@ -227,34 +235,40 @@ type ContentModerationTestAuditResult struct { } type UpdateContentModerationConfigInput struct { - Enabled *bool `json:"enabled"` - Mode *string `json:"mode"` - BaseURL *string `json:"base_url"` - Model *string `json:"model"` - APIKey *string `json:"api_key"` - APIKeys *[]string `json:"api_keys"` - APIKeysMode string `json:"api_keys_mode"` - DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` - ClearAPIKey bool `json:"clear_api_key"` - TimeoutMS *int `json:"timeout_ms"` - SampleRate *int `json:"sample_rate"` - AllGroups *bool `json:"all_groups"` - GroupIDs *[]int64 `json:"group_ids"` - RecordNonHits *bool `json:"record_non_hits"` - WorkerCount *int `json:"worker_count"` - QueueSize *int `json:"queue_size"` - BlockStatus *int `json:"block_status"` - BlockMessage *string `json:"block_message"` - EmailOnHit *bool `json:"email_on_hit"` - AutoBanEnabled *bool `json:"auto_ban_enabled"` - BanThreshold *int `json:"ban_threshold"` - ViolationWindowHours *int `json:"violation_window_hours"` - RetryCount *int `json:"retry_count"` - HitRetentionDays *int `json:"hit_retention_days"` - NonHitRetentionDays *int `json:"non_hit_retention_days"` - PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` - BlockedKeywords *[]string `json:"blocked_keywords"` - KeywordBlockingMode *string `json:"keyword_blocking_mode"` + Enabled *bool `json:"enabled"` + Mode *string `json:"mode"` + BaseURL *string `json:"base_url"` + Model *string `json:"model"` + APIKey *string `json:"api_key"` + APIKeys *[]string `json:"api_keys"` + APIKeysMode string `json:"api_keys_mode"` + DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"` + ClearAPIKey bool `json:"clear_api_key"` + TimeoutMS *int `json:"timeout_ms"` + SampleRate *int `json:"sample_rate"` + AllGroups *bool `json:"all_groups"` + GroupIDs *[]int64 `json:"group_ids"` + RecordNonHits *bool `json:"record_non_hits"` + WorkerCount *int `json:"worker_count"` + QueueSize *int `json:"queue_size"` + BlockStatus *int `json:"block_status"` + BlockMessage *string `json:"block_message"` + EmailOnHit *bool `json:"email_on_hit"` + AutoBanEnabled *bool `json:"auto_ban_enabled"` + BanThreshold *int `json:"ban_threshold"` + ViolationWindowHours *int `json:"violation_window_hours"` + RetryCount *int `json:"retry_count"` + HitRetentionDays *int `json:"hit_retention_days"` + NonHitRetentionDays *int `json:"non_hit_retention_days"` + PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"` + BlockedKeywords *[]string `json:"blocked_keywords"` + KeywordBlockingMode *string `json:"keyword_blocking_mode"` + ModelFilter *ContentModerationModelFilter `json:"model_filter"` +} + +type ContentModerationModelFilter struct { + Type string `json:"type"` + Models []string `json:"models"` } type ContentModerationCheckInput struct { @@ -581,6 +595,9 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat if input.KeywordBlockingMode != nil { cfg.KeywordBlockingMode = strings.TrimSpace(*input.KeywordBlockingMode) } + if input.ModelFilter != nil { + cfg.ModelFilter = *input.ModelFilter + } if input.AllGroups != nil { cfg.AllGroups = *input.AllGroups } @@ -719,7 +736,8 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "error", err) return allow, nil } - inScope := cfg.includesGroup(input.GroupID) + inGroupScope := cfg.includesGroup(input.GroupID) + inModelScope := cfg.includesModel(input.Model) slog.Info("content_moderation.config_loaded", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -733,7 +751,10 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "mode", cfg.Mode, "all_groups", cfg.AllGroups, "configured_group_ids", cfg.GroupIDs, - "in_scope", inScope, + "in_group_scope", inGroupScope, + "model_filter_type", cfg.ModelFilter.Type, + "configured_models", cfg.ModelFilter.Models, + "in_model_scope", inModelScope, "sample_rate", cfg.SampleRate, "api_key_count", len(cfg.apiKeys()), "pre_hash_check_enabled", cfg.PreHashCheckEnabled, @@ -756,7 +777,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "protocol", input.Protocol) return allow, nil } - if !inScope { + if !inGroupScope { slog.Info("content_moderation.skip_group_out_of_scope", "user_id", input.UserID, "api_key_id", input.APIKeyID, @@ -768,6 +789,19 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer "configured_group_ids", cfg.GroupIDs) return allow, nil } + if !inModelScope { + slog.Info("content_moderation.skip_model_out_of_scope", + "user_id", input.UserID, + "api_key_id", input.APIKeyID, + "group_id", contentModerationLogGroupID(input.GroupID), + "group_name", input.GroupName, + "endpoint", input.Endpoint, + "protocol", input.Protocol, + "model", input.Model, + "model_filter_type", cfg.ModelFilter.Type, + "configured_models", cfg.ModelFilter.Models) + return allow, nil + } content := ExtractContentModerationInput(input.Protocol, input.Body) if content.IsEmpty() { slog.Info("content_moderation.skip_empty_input", @@ -1025,6 +1059,9 @@ func (s *ContentModerationService) worker(id int) { if !cfg.includesGroup(task.input.GroupID) { return } + if !cfg.includesModel(task.input.Model) { + return + } s.asyncActive.Add(1) defer s.asyncActive.Add(-1) queueDelay := int(time.Since(task.enqueuedAt).Milliseconds()) @@ -1270,6 +1307,9 @@ func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *Cont if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 { return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间") } + if cfg.ModelFilter.Type != ContentModerationModelFilterAll && len(cfg.ModelFilter.Models) == 0 { + return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODEL_FILTER", "指定或排除模型时至少需要配置 1 个模型") + } if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil { for _, groupID := range cfg.GroupIDs { if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil { @@ -1590,6 +1630,10 @@ func defaultContentModerationConfig() *ContentModerationConfig { PreHashCheckEnabled: false, BlockedKeywords: []string{}, KeywordBlockingMode: ContentModerationKeywordModeKeywordAndAPI, + ModelFilter: ContentModerationModelFilter{ + Type: ContentModerationModelFilterAll, + Models: []string{}, + }, } } @@ -1670,6 +1714,7 @@ func (cfg *ContentModerationConfig) normalize() { cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds) cfg.BlockedKeywords = normalizeBlockedKeywords(cfg.BlockedKeywords) cfg.KeywordBlockingMode = normalizeKeywordBlockingMode(cfg.KeywordBlockingMode) + cfg.ModelFilter = normalizeContentModerationModelFilter(cfg.ModelFilter) } func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { @@ -1687,6 +1732,21 @@ func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool { return false } +func (cfg *ContentModerationConfig) includesModel(model string) bool { + if cfg == nil { + return true + } + filter := normalizeContentModerationModelFilter(cfg.ModelFilter) + switch filter.Type { + case ContentModerationModelFilterInclude: + return contentModerationModelListContains(filter.Models, model) + case ContentModerationModelFilterExclude: + return !contentModerationModelListContains(filter.Models, model) + default: + return true + } +} + func contentModerationLogGroupID(groupID *int64) int64 { if groupID == nil { return 0 @@ -1848,6 +1908,7 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con PreHashCheckEnabled: cfg.PreHashCheckEnabled, BlockedKeywords: append([]string(nil), cfg.BlockedKeywords...), KeywordBlockingMode: cfg.KeywordBlockingMode, + ModelFilter: cloneContentModerationModelFilter(cfg.ModelFilter), } } @@ -2125,6 +2186,73 @@ func normalizeKeywordBlockingMode(mode string) string { } } +func normalizeContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter { + out := ContentModerationModelFilter{ + Type: normalizeContentModerationModelFilterType(filter.Type), + Models: normalizeContentModerationModelNames(filter.Models), + } + if out.Type == ContentModerationModelFilterAll { + out.Models = []string{} + } + return out +} + +func cloneContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter { + normalized := normalizeContentModerationModelFilter(filter) + normalized.Models = append([]string(nil), normalized.Models...) + return normalized +} + +func normalizeContentModerationModelFilterType(filterType string) string { + switch strings.ToLower(strings.TrimSpace(filterType)) { + case ContentModerationModelFilterInclude: + return ContentModerationModelFilterInclude + case ContentModerationModelFilterExclude: + return ContentModerationModelFilterExclude + case ContentModerationModelFilterAll: + return ContentModerationModelFilterAll + default: + return ContentModerationModelFilterAll + } +} + +func normalizeContentModerationModelNames(models []string) []string { + if len(models) == 0 { + return []string{} + } + out := make([]string, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for _, raw := range models { + model := trimRunes(strings.TrimSpace(raw), maxContentModerationModelFilterRunes) + if model == "" { + continue + } + key := strings.ToLower(model) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, model) + if len(out) >= maxContentModerationModelFilterModels { + break + } + } + return out +} + +func contentModerationModelListContains(models []string, model string) bool { + model = strings.ToLower(strings.TrimSpace(model)) + if model == "" { + return false + } + for _, candidate := range models { + if strings.ToLower(strings.TrimSpace(candidate)) == model { + return true + } + } + return false +} + func matchBlockedKeyword(text string, keywords []string) (string, bool) { if text == "" || len(keywords) == 0 { return "", false diff --git a/backend/internal/service/content_moderation_input.go b/backend/internal/service/content_moderation_input.go index 67df397d..73d5b697 100644 --- a/backend/internal/service/content_moderation_input.go +++ b/backend/internal/service/content_moderation_input.go @@ -48,44 +48,44 @@ func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string, if !messages.IsArray() { return } - var lastParts []string - var lastImages []string - messages.ForEach(func(_, msg gjson.Result) bool { - if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role { - var candidate []string - var candidateImages []string - collectContentValue(msg.Get("content"), &candidate, &candidateImages) - if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { - lastParts = candidate - lastImages = candidateImages - } - } - return true - }) - *parts = append(*parts, lastParts...) - *images = append(*images, lastImages...) + array := messages.Array() + if len(array) == 0 { + return + } + last := array[len(array)-1] + if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != role { + return + } + var candidate []string + var candidateImages []string + collectContentValue(last.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 { + return + } + *parts = append(*parts, candidate...) + *images = append(*images, candidateImages...) } func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) { if !messages.IsArray() { return } - var lastParts []string - var lastImages []string - messages.ForEach(func(_, msg gjson.Result) bool { - if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" { - var candidate []string - var candidateImages []string - collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages) - if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { - lastParts = candidate - lastImages = candidateImages - } - } - return true - }) - *parts = append(*parts, lastParts...) - *images = append(*images, lastImages...) + array := messages.Array() + if len(array) == 0 { + return + } + last := array[len(array)-1] + if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != "user" { + return + } + var candidate []string + var candidateImages []string + collectAnthropicUserContentValue(last.Get("content"), &candidate, &candidateImages) + if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 { + return + } + *parts = append(*parts, candidate...) + *images = append(*images, candidateImages...) } func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) { @@ -128,18 +128,17 @@ func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]st case input.Type == gjson.String: addModerationText(parts, input.String()) case input.IsArray(): - var last gjson.Result - input.ForEach(func(_, item gjson.Result) bool { - if isResponsesUserTextItem(item) { - last = item - } - return true - }) - if last.Exists() { - collectContentValue(last.Get("content"), parts, images) - if last.Get("type").String() == "input_text" || last.Get("text").Exists() { - collectContentValue(last, parts, images) - } + array := input.Array() + if len(array) == 0 { + return + } + last := array[len(array)-1] + if !isResponsesUserTextItem(last) { + return + } + collectContentValue(last.Get("content"), parts, images) + if last.Get("type").String() == "input_text" || last.Get("text").Exists() { + collectContentValue(last, parts, images) } case input.IsObject(): if isResponsesUserTextItem(input) { @@ -176,29 +175,29 @@ func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[] if !contents.IsArray() { return } - var lastParts []string - var lastImages []string - contents.ForEach(func(_, content gjson.Result) bool { - role := strings.ToLower(strings.TrimSpace(content.Get("role").String())) - if role == "" || role == "user" { - var candidate []string - var candidateImages []string - if arr := content.Get("parts"); arr.IsArray() { - arr.ForEach(func(_, part gjson.Result) bool { - addModerationText(&candidate, part.Get("text").String()) - addGeminiModerationImage(&candidateImages, part) - return true - }) - } - if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 { - lastParts = candidate - lastImages = candidateImages - } - } - return true - }) - *parts = append(*parts, lastParts...) - *images = append(*images, lastImages...) + array := contents.Array() + if len(array) == 0 { + return + } + last := array[len(array)-1] + role := strings.ToLower(strings.TrimSpace(last.Get("role").String())) + if role != "" && role != "user" { + return + } + var candidate []string + var candidateImages []string + if arr := last.Get("parts"); arr.IsArray() { + arr.ForEach(func(_, part gjson.Result) bool { + addModerationText(&candidate, part.Get("text").String()) + addGeminiModerationImage(&candidateImages, part) + return true + }) + } + if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 { + return + } + *parts = append(*parts, candidate...) + *images = append(*images, candidateImages...) } func collectContentValue(value gjson.Result, parts *[]string, images *[]string) { diff --git a/backend/internal/service/content_moderation_input_test.go b/backend/internal/service/content_moderation_input_test.go new file mode 100644 index 00000000..d51dc21b --- /dev/null +++ b/backend/internal/service/content_moderation_input_test.go @@ -0,0 +1,179 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// 当数组末尾不是用户消息时(典型场景:Agent 工具循环结束于 tool/assistant), +// 应直接跳过审计——不再回溯查找历史中的某条用户消息。 + +func TestExtractContentModerationInput_AnthropicAgentToolLoopSkipsAudit(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"调用一下天气工具"}, + {"role":"assistant","content":[{"type":"tool_use","id":"tool_1","name":"weather","input":{}}]}, + {"role":"user","content":[{"type":"tool_result","tool_use_id":"tool_1","content":"晴 25 度"}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Empty(t, input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_AnthropicFirstTurnExtractsUser(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"Q1"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Equal(t, "Q1", input.Text) +} + +func TestExtractContentModerationInput_AnthropicMultiTurnExtractsLatestUser(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"Q1"}, + {"role":"assistant","content":"A1"}, + {"role":"user","content":"Q2"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Equal(t, "Q2", input.Text) +} + +func TestExtractContentModerationInput_AnthropicStreamResendExtractsResend(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"原问题"}, + {"role":"assistant","content":"部分回答……"}, + {"role":"user","content":"重发"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body) + + require.Equal(t, "重发", input.Text) +} + +func TestExtractContentModerationInput_OpenAIChatAgentToolLoopSkipsAudit(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"system","content":"sys"}, + {"role":"user","content":"列出我的订单"}, + {"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"orders","arguments":"{}"}}]}, + {"role":"tool","tool_call_id":"call_1","content":"[]"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body) + + require.Empty(t, input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_OpenAIChatMultiTurnExtractsLatestUser(t *testing.T) { + body := []byte(`{ + "messages": [ + {"role":"user","content":"Q1"}, + {"role":"assistant","content":"A1"}, + {"role":"user","content":"Q2"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body) + + require.Equal(t, "Q2", input.Text) +} + +func TestExtractContentModerationInput_GeminiAgentToolLoopSkipsAudit(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role":"user","parts":[{"text":"查询天气"}]}, + {"role":"model","parts":[{"functionCall":{"name":"weather","args":{}}}]}, + {"role":"user","parts":[{"functionResponse":{"name":"weather","response":{"temp":25}}}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolGemini, body) + + require.Empty(t, input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_GeminiFirstTurnExtractsUser(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role":"user","parts":[{"text":"你好"}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolGemini, body) + + require.Equal(t, "你好", input.Text) +} + +func TestExtractContentModerationInput_GeminiMultiTurnExtractsLatestUser(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role":"user","parts":[{"text":"Q1"}]}, + {"role":"model","parts":[{"text":"A1"}]}, + {"role":"user","parts":[{"text":"Q2"}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolGemini, body) + + require.Equal(t, "Q2", input.Text) +} + +func TestExtractContentModerationInput_ResponsesAgentToolLoopSkipsAudit(t *testing.T) { + body := []byte(`{ + "input":[ + {"type":"message","role":"user","content":[{"type":"input_text","text":"运行测试"}]}, + {"type":"function_call","call_id":"call_1","name":"run_tests","arguments":"{}"}, + {"type":"function_call_output","call_id":"call_1","output":"all passed"} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body) + + require.Empty(t, input.Text) + require.Empty(t, input.Images) +} + +func TestExtractContentModerationInput_ResponsesLastUserMessageExtracted(t *testing.T) { + body := []byte(`{ + "input":[ + {"type":"message","role":"user","content":[{"type":"input_text","text":"first"}]}, + {"type":"message","role":"assistant","content":[{"type":"output_text","text":"answer"}]}, + {"type":"message","role":"user","content":[{"type":"input_text","text":"latest"}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body) + + require.Equal(t, "latest", input.Text) +} + +func TestExtractContentModerationInput_ResponsesLastIsAssistantSkipped(t *testing.T) { + body := []byte(`{ + "input":[ + {"type":"message","role":"user","content":[{"type":"input_text","text":"q1"}]}, + {"type":"message","role":"assistant","content":[{"type":"output_text","text":"a1"}]} + ] + }`) + + input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body) + + require.Empty(t, input.Text) + require.Empty(t, input.Images) +} diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go index 30578ca5..60a99318 100644 --- a/backend/internal/service/content_moderation_test.go +++ b/backend/internal/service/content_moderation_test.go @@ -530,6 +530,147 @@ func TestNormalizeKeywordBlockingMode_UnknownFallsBackToDefault(t *testing.T) { require.Equal(t, ContentModerationKeywordModeAPIOnly, normalizeKeywordBlockingMode("api_only")) } +func TestContentModerationCheck_ModelFilterAllAuditsEveryModel(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterAll} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + for _, model := range []string{"gpt-5.5", "gpt-5.4"} { + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: model, + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + } + require.Len(t, repo.logs, 2) +} + +func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.4", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Allowed) + require.False(t, decision.Blocked) + require.Equal(t, ContentModerationActionAllow, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterExclude, Models: []string{"gpt-5.4"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + + decision, err = svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.4", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + require.NoError(t, err) + require.True(t, decision.Allowed) + require.False(t, decision.Blocked) + require.Equal(t, ContentModerationActionAllow, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func TestContentModerationLoadConfig_LegacyConfigDefaultsModelFilterToAll(t *testing.T) { + raw := `{"enabled":true,"mode":"pre_block","base_url":"https://api.openai.com","model":"omni-moderation-latest","blocked_keywords":["secret-token"]}` + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyContentModerationConfig: raw, + }}, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + cfg, err := svc.loadConfig(context.Background()) + + require.NoError(t, err) + require.Equal(t, ContentModerationModelFilterAll, cfg.ModelFilter.Type) + require.Empty(t, cfg.ModelFilter.Models) + require.True(t, cfg.includesModel("gpt-5.5")) + require.True(t, cfg.includesModel("gpt-5.4")) +} + +func TestContentModerationCheck_ModelFilterUsesRequestedModelNotBodyModel(t *testing.T) { + cfg := defaultContentModerationModelFilterTestConfig() + cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}} + svc, repo := newContentModerationModelFilterTestService(t, cfg) + + decision, err := svc.Check(context.Background(), ContentModerationCheckInput{ + Model: "gpt-5.5", + Protocol: ContentModerationProtocolOpenAIChat, + Body: []byte(`{"model":"mapped-upstream-model","messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`), + }) + + require.NoError(t, err) + require.True(t, decision.Blocked) + require.Equal(t, ContentModerationActionKeywordBlock, decision.Action) + require.Len(t, repo.logs, 1) + require.Equal(t, "gpt-5.5", repo.logs[0].Model) +} + +func defaultContentModerationModelFilterTestConfig() *ContentModerationConfig { + cfg := defaultContentModerationConfig() + cfg.Enabled = true + cfg.Mode = ContentModerationModePreBlock + cfg.BlockedKeywords = []string{"secret-token"} + return cfg +} + +func newContentModerationModelFilterTestService(t *testing.T, cfg *ContentModerationConfig) (*ContentModerationService, *contentModerationTestRepo) { + t.Helper() + rawCfg, err := json.Marshal(cfg) + require.NoError(t, err) + repo := &contentModerationTestRepo{} + svc := NewContentModerationService( + &contentModerationTestSettingRepo{values: map[string]string{ + SettingKeyRiskControlEnabled: "true", + SettingKeyContentModerationConfig: string(rawCfg), + }}, + repo, + &contentModerationTestHashCache{}, + nil, + nil, + nil, + nil, + ) + return svc, repo +} + func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) { cfg := defaultContentModerationConfig() cfg.APIKeys = []string{"sk-old-a", "sk-old-b"} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 8d591086..2b22e94d 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -132,6 +132,9 @@ const ( SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key + // API Key IP 访问控制设置 + SettingKeyAPIKeyACLTrustForwardedIP = "api_key_acl_trust_forwarded_ip" // API Key IP 白/黑名单是否信任转发 IP + // TOTP 双因素认证设置 SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能 @@ -406,12 +409,15 @@ const ( // 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。 SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent" - // Balance Low Notification + // 余额不足提醒 SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关 SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD) SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL - // Account Quota Notification + // 订阅到期提醒 + SettingKeySubscriptionExpiryNotifyEnabled = "subscription_expiry_notify_enabled" // 订阅到期提醒全局开关,默认开启 + + // 账号限额通知 SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6dd81dcd..0e6ce24d 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5739,6 +5739,31 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, } } +// ApplyBedrockCCCompat 应用 Bedrock CC 兼容转换(渠道级模型映射后调用) +// 清理 Anthropic API 专有字段、注入 Bedrock 必需字段、修复 thinking/tool_use ID +func (s *GatewayService) ApplyBedrockCCCompat(ctx context.Context, body []byte, model string, account *Account, groupID *int64) []byte { + if !s.isBedrockCCCompatEnabled(ctx, account, groupID) { + return body + } + body = sanitizeBedrockCCFields(body) + body = sanitizeBedrockThinking(body, model) + body = sanitizeBedrockToolUseIDs(body) + body = sanitizeBedrockCCBetaTokens(body, model) + return body +} + +// isBedrockCCCompatEnabled 检查渠道是否启用了 Bedrock CC 兼容模式 +func (s *GatewayService) isBedrockCCCompatEnabled(ctx context.Context, account *Account, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil || ch == nil { + return false + } + return ch.IsBedrockCCCompatEnabled(account.Platform) +} + // forwardBedrock 转发请求到 AWS Bedrock func (s *GatewayService) forwardBedrock( ctx context.Context, @@ -5771,7 +5796,7 @@ func (s *GatewayService) forwardBedrock( return nil, err } - bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) + bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens, false) if err != nil { return nil, fmt.Errorf("prepare bedrock request body: %w", err) } diff --git a/backend/internal/service/gateway_service_bedrock_beta_test.go b/backend/internal/service/gateway_service_bedrock_beta_test.go index 8920ee08..fa2feda1 100644 --- a/backend/internal/service/gateway_service_bedrock_beta_test.go +++ b/backend/internal/service/gateway_service_bedrock_beta_test.go @@ -126,17 +126,17 @@ func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *test } } -// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证: -// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, -// 但请求体包含 thinking 字段 → 自动注入后应被 block。 -func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) { +// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse 验证: +// 管理员 block 了 computer-use,客户端不在 header 中带该 token, +// 但请求体包含 computer_use 工具 → 自动注入后应被 block。 +func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse(t *testing.T) { settings := &BetaPolicySettings{ Rules: []BetaPolicyRule{ { - BetaToken: "interleaved-thinking-2025-05-14", + BetaToken: "computer-use-2025-11-24", Action: BetaPolicyActionBlock, Scope: BetaPolicyScopeAll, - ErrorMessage: "thinking is blocked", + ErrorMessage: "computer use is blocked", }, }, } @@ -155,18 +155,18 @@ func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *te } account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} - // header 中不带 beta token,但 body 中有 thinking 字段 + // header 中不带 beta token,但 body 中有 computer_use 工具 _, err = svc.resolveBedrockBetaTokensForRequest( context.Background(), account, "", // 空 header - []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + []byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`), "us.anthropic.claude-opus-4-6-v1", ) if err == nil { - t.Fatal("expected body-injected interleaved-thinking to be blocked") + t.Fatal("expected body-injected computer-use to be blocked") } - if err.Error() != "thinking is blocked" { + if err.Error() != "computer use is blocked" { t.Fatalf("unexpected error: %v", err) } } @@ -222,10 +222,10 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test settings := &BetaPolicySettings{ Rules: []BetaPolicyRule{ { - BetaToken: "computer-use-2025-11-24", + BetaToken: "context-1m-2025-08-07", Action: BetaPolicyActionBlock, Scope: BetaPolicyScopeAll, - ErrorMessage: "computer use is blocked", + ErrorMessage: "context is blocked", }, }, } @@ -244,12 +244,12 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test } account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock} - // body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use + // body 中有 computer_use 工具(会注入 computer-use token),但 block 规则只针对 context-1m tokens, err := svc.resolveBedrockBetaTokensForRequest( context.Background(), account, "", - []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`), + []byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`), "us.anthropic.claude-opus-4-6-v1", ) if err != nil { @@ -257,11 +257,11 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test } found := false for _, token := range tokens { - if token == "interleaved-thinking-2025-05-14" { + if token == "computer-use-2025-11-24" { found = true } } if !found { - t.Fatal("expected interleaved-thinking token to be present") + t.Fatal("expected computer-use token to be present") } } diff --git a/backend/internal/service/openai_account_runtime_block_fastpath.go b/backend/internal/service/openai_account_runtime_block_fastpath.go new file mode 100644 index 00000000..41a309fd --- /dev/null +++ b/backend/internal/service/openai_account_runtime_block_fastpath.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "net/http" + "time" +) + +const ( + openAIAccountStateUpdateTimeout = 5 * time.Second + openAIOAuth429FallbackCooldown = 5 * time.Second + openAIStopSchedulingBridgeCooldown = 2 * time.Minute + openAIOAuth429StormWindow = 10 * time.Second + openAIOAuth429StormThreshold = 20 + openAIOAuth429StormMaxAccountSwitches = 1 +) + +func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, openAIAccountStateUpdateTimeout) +} + +func isOpenAIOAuthAccount(account *Account) bool { + return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth +} + +func isOpenAIAccount(account *Account) bool { + return account != nil && account.Platform == PlatformOpenAI +} + +func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool { + stateCtx, cancel := openAIAccountStateContext(ctx) + defer cancel() + + if statusCode == http.StatusTooManyRequests { + s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody) + } + if s == nil || account == nil || s.rateLimitService == nil { + return false + } + shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody) + if shouldDisable { + s.BlockAccountScheduling(account, time.Time{}, "upstream_disable") + } + return shouldDisable +} + +func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) { + if s == nil || !isOpenAIOAuthAccount(account) { + return + } + s.recordOpenAIOAuth429() + + cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown) + if s.rateLimitService != nil { + if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) { + cooldownUntil = *resetAt + } else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil { + if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) { + cooldownUntil = resetAt + } + } else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 { + cooldownUntil = time.Now().Add(cooldown) + } + } + s.BlockAccountScheduling(account, cooldownUntil, "429") +} + +func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) { + if s == nil || !isOpenAIAccount(account) { + return + } + now := time.Now() + blockUntil := until + if blockUntil.IsZero() || !blockUntil.After(now) { + blockUntil = now.Add(openAIStopSchedulingBridgeCooldown) + } + + for { + current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID) + if !loaded { + actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil) + if !stored { + return + } + current = actual + } + + currentUntil, ok := current.(time.Time) + if !ok || currentUntil.IsZero() { + if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) { + return + } + continue + } + if currentUntil.After(blockUntil) { + return + } + if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) { + return + } + } +} + +func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiAccountRuntimeBlockUntil.Delete(accountID) +} + +func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool { + if s == nil || !isOpenAIAccount(account) { + return false + } + value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID) + if !ok { + return false + } + cooldownUntil, ok := value.(time.Time) + if !ok || cooldownUntil.IsZero() { + s.openaiAccountRuntimeBlockUntil.Delete(account.ID) + return false + } + if time.Now().Before(cooldownUntil) { + return true + } + s.openaiAccountRuntimeBlockUntil.Delete(account.ID) + return false +} + +func (s *OpenAIGatewayService) recordOpenAIOAuth429() { + if s == nil { + return + } + now := time.Now() + windowStart := s.openaiOAuth429WindowStartUnixNano.Load() + if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow { + if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) { + s.openaiOAuth429WindowCount.Store(1) + return + } + } + s.openaiOAuth429WindowCount.Add(1) +} + +func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool { + if s == nil { + return false + } + windowStart := s.openaiOAuth429WindowStartUnixNano.Load() + if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow { + return false + } + return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold +} + +func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool { + if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches { + return false + } + if !isOpenAIOAuthAccount(account) { + return false + } + return s.isOpenAIOAuth429Storm() +} diff --git a/backend/internal/service/openai_account_runtime_block_fastpath_test.go b/backend/internal/service/openai_account_runtime_block_fastpath_test.go new file mode 100644 index 00000000..95336e81 --- /dev/null +++ b/backend/internal/service/openai_account_runtime_block_fastpath_test.go @@ -0,0 +1,101 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil) + apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil) + + require.False(t, shouldDisable) + require.False(t, apiKeyShouldDisable) + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) + require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount)) +} + +func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code") + + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth} + + svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code") + + require.False(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) { + gateway := &OpenAIGatewayService{} + repo := &rateLimitAccountRepoStub{} + rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + rateLimitService.SetAccountRuntimeBlocker(gateway) + account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth} + + shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden")) + + require.True(t, shouldDisable) + require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + longUntil := time.Now().Add(10 * time.Minute) + + svc.BlockAccountScheduling(account, longUntil, "oauth_401") + svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable") + + value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID) + require.True(t, ok) + actualUntil, ok := value.(time.Time) + require.True(t, ok) + require.WithinDuration(t, longUntil, actualUntil, time.Second) +} + +func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + + svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429") + require.True(t, svc.isOpenAIAccountRuntimeBlocked(account)) + + svc.ClearAccountSchedulingBlock(account.ID) + require.False(t, svc.isOpenAIAccountRuntimeBlocked(account)) +} + +func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) { + svc := &OpenAIGatewayService{} + account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1)) + + for i := 0; i < openAIOAuth429StormThreshold; i++ { + svc.recordOpenAIOAuth429() + } + + require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1)) + require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0)) +} diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index c63151ae..121cf714 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct { loadSkewMilliTotal atomic.Int64 } +type openAIAccountLoadPlan struct { + allCandidates []openAIAccountCandidateScore + candidates []openAIAccountCandidateScore + staleSnapshotCompactRetry []openAIAccountCandidateScore + selectionOrder []openAIAccountCandidateScore + candidateCount int + topK int + loadSkew float64 +} + func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { if m == nil { return @@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( } result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if acquireErr == nil && result.Acquired { + if acquireErr == nil && result != nil && result.Acquired { _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) return &AccountSelectionResult{ Account: account, @@ -586,6 +596,234 @@ func buildOpenAIWeightedSelectionOrder( return order } +func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan( + req OpenAIAccountScheduleRequest, + filtered []*Account, + loadMap map[int64]*AccountLoadInfo, +) openAIAccountLoadPlan { + allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + errorRate, ttft, hasTTFT := 0.0, 0.0, false + if s.stats != nil { + errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID) + } + allCandidates = append(allCandidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + + candidates := allCandidates + staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates)) + if req.RequireCompact { + candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates)) + for _, candidate := range allCandidates { + if openAICompactSupportTier(candidate.account) == 0 { + staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate) + continue + } + candidates = append(candidates, candidate) + } + } + + plan := openAIAccountLoadPlan{ + allCandidates: allCandidates, + candidates: candidates, + staleSnapshotCompactRetry: staleSnapshotCompactRetry, + candidateCount: len(candidates), + } + if len(candidates) == 0 { + plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan) + return plan + } + + minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + for _, candidate := range candidates { + if candidate.account.Priority < minPriority { + minPriority = candidate.account.Priority + } + if candidate.account.Priority > maxPriority { + maxPriority = candidate.account.Priority + } + if candidate.loadInfo.WaitingCount > maxWaiting { + maxWaiting = candidate.loadInfo.WaitingCount + } + if candidate.hasTTFT && candidate.ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = candidate.ttft, candidate.ttft + hasTTFTSample = true + } else { + if candidate.ttft < minTTFT { + minTTFT = candidate.ttft + } + if candidate.ttft > maxTTFT { + maxTTFT = candidate.ttft + } + } + } + loadRate := float64(candidate.loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + } + plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + quotaFactor := item.account.GetQuotaRemainingFraction() + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + + weights.Quota*quotaFactor + } + plan.candidates = candidates + + plan.topK = s.service.openAIWSLBTopK() + if plan.topK > len(candidates) { + plan.topK = len(candidates) + } + if plan.topK <= 0 { + plan.topK = 1 + } + + plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan) + return plan +} + +func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder( + req OpenAIAccountScheduleRequest, + plan openAIAccountLoadPlan, +) []openAIAccountCandidateScore { + buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 || plan.topK <= 0 { + return nil + } + groupTopK := plan.topK + if groupTopK > len(pool) { + groupTopK = len(pool) + } + ranked := selectTopKOpenAICandidates(pool, groupTopK) + return buildOpenAIWeightedSelectionOrder(ranked, req) + } + + if req.RequireCompact { + supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates)) + unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates)) + for _, candidate := range plan.candidates { + switch openAICompactSupportTier(candidate.account) { + case 2: + supported = append(supported, candidate) + case 1: + unknown = append(unknown, candidate) + } + } + selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates)) + selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...) + selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...) + if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil { + selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...) + } + return selectionOrder + } + + return buildSelectionOrder(plan.candidates) +} + +func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { + if len(pool) == 0 { + return nil + } + ordered := append([]openAIAccountCandidateScore(nil), pool...) + sort.SliceStable(ordered, func(i, j int) bool { + a, b := ordered[i], ordered[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount { + return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + return ordered +} + +func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder( + ctx context.Context, + req OpenAIAccountScheduleRequest, + selectionOrder []openAIAccountCandidateScore, +) (*AccountSelectionResult, bool, error) { + compactBlocked := false + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { + continue + } + fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) + if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { + continue + } + if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { + compactBlocked = true + continue + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if acquireErr != nil { + return nil, compactBlocked, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + } + return &AccountSelectionResult{ + Account: fresh, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, compactBlocked, nil + } + } + return nil, compactBlocked, nil +} + func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ctx context.Context, req OpenAIAccountScheduleRequest, @@ -616,8 +854,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if !account.IsSchedulable() || !account.IsOpenAI() { continue } + if s.service.isOpenAIAccountRuntimeBlocked(account) { + continue + } // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() { + s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set") _ = s.service.accountRepo.SetError(ctx, account.ID, fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) continue @@ -645,214 +887,47 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } } - allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered)) - for _, account := range filtered { - loadInfo := loadMap[account.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: account.ID} - } - errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) - allCandidates = append(allCandidates, openAIAccountCandidateScore{ - account: account, - loadInfo: loadInfo, - errorRate: errorRate, - ttft: ttft, - hasTTFT: hasTTFT, - }) + plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap) + candidateCount := plan.candidateCount + topK := plan.topK + loadSkew := plan.loadSkew + selectionOrder := plan.selectionOrder + if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 { + return nil, 0, 0, 0, ErrNoAvailableCompactAccounts } - // Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用 - // 时作为最后兜底(snapshot 可能已陈旧)。 - candidates := allCandidates - staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates)) - if req.RequireCompact { - candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates)) - for _, candidate := range allCandidates { - if openAICompactSupportTier(candidate.account) == 0 { - staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate) - continue - } - candidates = append(candidates, candidate) - } - if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 { - return nil, 0, 0, 0, ErrNoAvailableCompactAccounts - } - } - - candidateCount := len(candidates) - loadSkew := 0.0 - if len(candidates) > 0 { - minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority - maxWaiting := 1 - loadRateSum := 0.0 - loadRateSumSquares := 0.0 - minTTFT, maxTTFT := 0.0, 0.0 - hasTTFTSample := false - for _, candidate := range candidates { - if candidate.account.Priority < minPriority { - minPriority = candidate.account.Priority - } - if candidate.account.Priority > maxPriority { - maxPriority = candidate.account.Priority - } - if candidate.loadInfo.WaitingCount > maxWaiting { - maxWaiting = candidate.loadInfo.WaitingCount - } - if candidate.hasTTFT && candidate.ttft > 0 { - if !hasTTFTSample { - minTTFT, maxTTFT = candidate.ttft, candidate.ttft - hasTTFTSample = true - } else { - if candidate.ttft < minTTFT { - minTTFT = candidate.ttft - } - if candidate.ttft > maxTTFT { - maxTTFT = candidate.ttft - } - } - } - loadRate := float64(candidate.loadInfo.LoadRate) - loadRateSum += loadRate - loadRateSumSquares += loadRate * loadRate - } - loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) - - weights := s.service.openAIWSSchedulerWeights() - for i := range candidates { - item := &candidates[i] - priorityFactor := 1.0 - if maxPriority > minPriority { - priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) - } - loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) - queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) - errorFactor := 1 - clamp01(item.errorRate) - ttftFactor := 0.5 - if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { - ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) - } - quotaFactor := item.account.GetQuotaRemainingFraction() - - item.score = weights.Priority*priorityFactor + - weights.Load*loadFactor + - weights.Queue*queueFactor + - weights.ErrorRate*errorFactor + - weights.TTFT*ttftFactor + - weights.Quota*quotaFactor - } - - if s.service.openAIWSP2CEnabled() { - return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew) - } - } - - topK := 0 - if len(candidates) > 0 { - topK = s.service.openAIWSLBTopK() - if topK > len(candidates) { - topK = len(candidates) - } - if topK <= 0 { - topK = 1 - } - } - - buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { - if len(pool) == 0 || topK <= 0 { - return nil - } - groupTopK := topK - if groupTopK > len(pool) { - groupTopK = len(pool) - } - ranked := selectTopKOpenAICandidates(pool, groupTopK) - return buildOpenAIWeightedSelectionOrder(ranked, req) - } - sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore { - if len(pool) == 0 { - return nil - } - ordered := append([]openAIAccountCandidateScore(nil), pool...) - sort.SliceStable(ordered, func(i, j int) bool { - a, b := ordered[i], ordered[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount { - return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - return ordered - } - - selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates)) - if req.RequireCompact { - supported := make([]openAIAccountCandidateScore, 0, len(candidates)) - unknown := make([]openAIAccountCandidateScore, 0, len(candidates)) - for _, candidate := range candidates { - switch openAICompactSupportTier(candidate.account) { - case 2: - supported = append(supported, candidate) - case 1: - unknown = append(unknown, candidate) - } - } - if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil { - return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts - } - selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...) - selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...) - if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil { - selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...) - } - } else { - selectionOrder = buildSelectionOrder(candidates) + if s.service.openAIWSP2CEnabled() && len(plan.candidates) > 0 { + return s.selectByPowerOfTwo(ctx, req, plan.candidates, loadSkew) } if len(selectionOrder) == 0 { - return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0) + return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0) } - compactBlocked := false - for i := 0; i < len(selectionOrder); i++ { - candidate := selectionOrder[i] - fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { - continue - } - fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false) - if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) { - continue - } - if req.RequireCompact && openAICompactSupportTier(fresh) == 0 { - compactBlocked = true - continue - } - result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if acquireErr != nil { - return nil, candidateCount, topK, loadSkew, acquireErr - } - if result != nil && result.Acquired { - if req.SessionHash != "" { - _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID) + result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder) + if acquireErr != nil { + return nil, candidateCount, topK, loadSkew, acquireErr + } + if result != nil { + return result, candidateCount, topK, loadSkew, nil + } + + if s.service.concurrencyService != nil { + if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil { + freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap) + if len(freshPlan.selectionOrder) > 0 { + freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder) + if freshAcquireErr != nil { + return nil, candidateCount, topK, loadSkew, freshAcquireErr + } + if freshResult != nil { + return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil + } + compactBlocked = compactBlocked || freshCompactBlocked + selectionOrder = freshPlan.selectionOrder + candidateCount = freshPlan.candidateCount + topK = freshPlan.topK + loadSkew = freshPlan.loadSkew } - return &AccountSelectionResult{ - Account: fresh, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, candidateCount, topK, loadSkew, nil } } @@ -899,6 +974,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C if account == nil { return false } + if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) { + return false + } if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { return false } diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index f8b23a28..27eb211e 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index c585290e..ad6d3e8d 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go index 6d74f7dd..336a7d79 100644 --- a/backend/internal/service/openai_gateway_messages.go +++ b/backend/internal/service/openai_gateway_messages.go @@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_responses_chat_fallback.go b/backend/internal/service/openai_gateway_responses_chat_fallback.go index 1d28a9c2..c3ebc35c 100644 --- a/backend/internal/service/openai_gateway_responses_chat_fallback.go +++ b/backend/internal/service/openai_gateway_responses_chat_fallback.go @@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions( Message: upstreamMsg, Detail: upstreamDetail, }) - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) return nil, &UpstreamFailoverError{ StatusCode: resp.StatusCode, ResponseBody: respBody, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f13c4748..f312f50d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat" @@ -354,6 +355,9 @@ type OpenAIGatewayService struct { openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time + openaiOAuth429WindowStartUnixNano atomic.Int64 + openaiOAuth429WindowCount atomic.Int64 openaiWSRetryMetrics openAIWSRetryMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter codexSnapshotThrottle *accountWriteThrottle @@ -417,6 +421,12 @@ func NewOpenAIGatewayService( responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } + if rateLimitService != nil { + rateLimitService.SetAccountRuntimeBlocker(svc) + } + if openAITokenProvider != nil { + openAITokenProvider.SetAccountRuntimeBlocker(svc) + } svc.logOpenAIWSModeBootstrap() return svc } @@ -962,7 +972,7 @@ func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, zap.String("request_path", strings.TrimSpace(req.URL.Path)), zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), zap.String("request_host", strings.TrimSpace(req.Host)), - zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_client_ip", strings.TrimSpace(ip.GetClientIP(c))), zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), @@ -1381,13 +1391,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) } + hydrated, err := s.hydrateSelectedAccount(ctx, selected) + if err != nil { + return nil, err + } + // 4. 设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } - return s.hydrateSelectedAccount(ctx, selected) + return hydrated, nil } // tryStickySessionHit 尝试从粘性会话获取账号。 @@ -1430,6 +1445,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { return nil } + if s.isOpenAIAccountRuntimeBlocked(account) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + return nil + } account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) @@ -1575,8 +1594,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex return nil, err } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + if err == nil && result != nil && result.Acquired { + return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc) } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) @@ -1627,13 +1646,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) if account == nil { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) + } else if s.isOpenAIAccountRuntimeBlocked(account) { + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } else { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc) + if selectErr != nil { + return nil, selectErr + } _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return selection, nil } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) @@ -1665,6 +1690,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex if !acc.IsSchedulable() { continue } + if s.isOpenAIAccountRuntimeBlocked(acc) { + continue + } if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } @@ -1687,6 +1715,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex }) } + tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) { + var available []accountWithLoad + for _, acc := range candidates { + loadInfo := loadMap[acc.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: acc.ID} + } + if loadInfo.LoadRate < 100 { + available = append(available, accountWithLoad{ + account: acc, + loadInfo: loadInfo, + }) + } + } + + if len(available) == 0 { + return nil, false, nil + } + + sort.SliceStable(available, func(i, j int) bool { + a, b := available[i], available[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + shuffleWithinSortGroups(available) + + selectionOrder := make([]accountWithLoad, 0, len(available)) + if requireCompact { + appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { + for _, item := range available { + if openAICompactSupportTier(item.account) == tier { + out = append(out, item) + } + } + return out + } + selectionOrder = appendTier(selectionOrder, 2) + selectionOrder = appendTier(selectionOrder, 1) + // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 + // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 + selectionOrder = appendTier(selectionOrder, 0) + } else { + selectionOrder = append(selectionOrder, available...) + } + + for _, item := range selectionOrder { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) + if fresh == nil { + continue + } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc) + if selectErr != nil { + return nil, true, selectErr + } + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + } + return selection, true, nil + } + } + return nil, true, nil + } + loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { ordered := append([]*Account(nil), candidates...) @@ -1707,87 +1821,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex continue } result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { + if err == nil && result != nil && result.Acquired { + selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc) + if selectErr != nil { + return nil, selectErr + } if sessionHash != "" { _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + return selection, nil } } } else { - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } - } - - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - shuffleWithinSortGroups(available) - - selectionOrder := make([]accountWithLoad, 0, len(available)) - if requireCompact { - appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { - for _, item := range available { - if openAICompactSupportTier(item.account) == tier { - out = append(out, item) - } - } - return out - } - selectionOrder = appendTier(selectionOrder, 2) - selectionOrder = appendTier(selectionOrder, 1) - // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 - // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 - selectionOrder = appendTier(selectionOrder, 0) - } else { - selectionOrder = append(selectionOrder, available...) - } - - for _, item := range selectionOrder { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) - if fresh == nil { - continue - } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) - if fresh == nil { - continue - } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue - } - result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) - } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil { + return nil, selectErr + } else if selection != nil { + return selection, nil + } else if attempted { + if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil { + if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil { + return nil, selectErr + } else if selection != nil { + return selection, nil } } } @@ -1868,6 +1923,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { return nil } + if s.isOpenAIAccountRuntimeBlocked(fresh) { + return nil + } return fresh } @@ -1889,6 +1947,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { return nil } + if s.isOpenAIAccountRuntimeBlocked(latest) { + return nil + } return latest } @@ -1935,6 +1996,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account * }, nil } +func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) { + selection, err := s.newSelectionResult(ctx, account, true, release, nil) + if err != nil && release != nil { + release() + } + return selection, err +} + func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling @@ -1996,7 +2065,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } // Forward forwards request to OpenAI API @@ -3278,9 +3347,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -3321,12 +3388,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - // Passthrough mode preserves the raw upstream error response, but runtime - // account state still needs to be updated so sticky routing can stop - // reusing a freshly rate-limited account. - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + // 透传模式保留原始上游错误响应,但运行态账号状态仍需更新, + // 避免粘性路由继续复用刚被限流的账号。 + _ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, @@ -4075,10 +4139,7 @@ func (s *OpenAIGatewayService) handleErrorResponse( } // Handle upstream error (mark account status) - shouldDisable := false - if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } + shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) kind := "http_error" if shouldDisable { kind = "failover" @@ -4210,12 +4271,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( } // Track rate limits and decide whether to trigger secondary failover. - shouldDisable := false - if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError( - c.Request.Context(), account, resp.StatusCode, resp.Header, body, - ) - } + shouldDisable := s.handleOpenAIAccountUpstreamError( + c.Request.Context(), account, resp.StatusCode, resp.Header, body, + ) kind := "http_error" if shouldDisable { kind = "failover" diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index 951860cd..17a874ea 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -131,8 +131,10 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.RemoteAddr = "172.18.0.1:54321" c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("X-Real-IP", "203.0.113.42") c.Request.Header.Set("OpenAI-Beta", "assistants=v2") body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) @@ -146,6 +148,8 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("request_client_ip", "203.0.113.42")) + require.True(t, logSink.ContainsFieldValue("request_remote_addr", "172.18.0.1:54321")) require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) require.True(t, logSink.ContainsField("request_body_size")) diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index 56272c26..b39fa609 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -29,6 +29,69 @@ type openAIResponsesImageResult struct { Model string } +type OpenAIImagesUpstreamError struct { + StatusCode int + ErrorType string + Code string + Message string + Param string + UpstreamRequestID string +} + +func (e *OpenAIImagesUpstreamError) Error() string { + if e == nil { + return "" + } + code := strings.TrimSpace(e.Code) + if code == "" { + code = strings.TrimSpace(e.ErrorType) + } + message := strings.TrimSpace(e.Message) + if code != "" && message != "" { + return fmt.Sprintf("openai images upstream error: %s: %s", code, message) + } + if message != "" { + return "openai images upstream error: " + message + } + if code != "" { + return "openai images upstream error: " + code + } + return "openai images upstream error" +} + +func (e *OpenAIImagesUpstreamError) clientStatusCode() int { + if e == nil { + return http.StatusBadGateway + } + if e.StatusCode > 0 { + return e.StatusCode + } + return http.StatusBadGateway +} + +func (e *OpenAIImagesUpstreamError) clientErrorType() string { + if e == nil { + return "upstream_error" + } + if trimmed := strings.TrimSpace(e.ErrorType); trimmed != "" { + return trimmed + } + return "upstream_error" +} + +func (e *OpenAIImagesUpstreamError) clientMessage() string { + if e == nil { + return "Upstream request failed" + } + if trimmed := strings.TrimSpace(e.Message); trimmed != "" { + return trimmed + } + if trimmed := strings.TrimSpace(e.Code); trimmed != "" { + return trimmed + } + return "Upstream request failed" +} + func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string { if strings.TrimSpace(result.Result) != "" { return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result) @@ -465,6 +528,57 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil } +func extractOpenAIImagesUpstreamError(body []byte) *OpenAIImagesUpstreamError { + var upstreamErr *OpenAIImagesUpstreamError + forEachOpenAISSEDataPayload(string(body), func(payload []byte) { + if upstreamErr != nil || !gjson.ValidBytes(payload) { + return + } + upstreamErr = openAIImagesUpstreamErrorFromSSEPayload(payload) + }) + return upstreamErr +} + +func openAIImagesUpstreamErrorFromSSEPayload(payload []byte) *OpenAIImagesUpstreamError { + if !gjson.ValidBytes(payload) { + return nil + } + switch gjson.GetBytes(payload, "type").String() { + case "error": + return openAIImagesUpstreamErrorFromGJSON(gjson.GetBytes(payload, "error"), "") + case "response.failed": + response := gjson.GetBytes(payload, "response") + return openAIImagesUpstreamErrorFromGJSON(response.Get("error"), response.Get("id").String()) + default: + return nil + } +} + +func openAIImagesUpstreamErrorFromGJSON(errorObj gjson.Result, upstreamRequestID string) *OpenAIImagesUpstreamError { + if !errorObj.Exists() { + return nil + } + code := strings.TrimSpace(errorObj.Get("code").String()) + errType := strings.TrimSpace(errorObj.Get("type").String()) + message := strings.TrimSpace(errorObj.Get("message").String()) + param := strings.TrimSpace(errorObj.Get("param").String()) + statusCode := http.StatusBadGateway + if strings.EqualFold(code, "moderation_blocked") || strings.EqualFold(errType, "image_generation_user_error") { + statusCode = http.StatusBadRequest + } + if message == "" { + message = "Upstream request failed" + } + return &OpenAIImagesUpstreamError{ + StatusCode: statusCode, + ErrorType: errType, + Code: code, + Message: sanitizeUpstreamErrorMessage(message), + Param: param, + UpstreamRequestID: strings.TrimSpace(upstreamRequestID), + } +} + func buildOpenAIImagesAPIResponse( results []openAIResponsesImageResult, createdAt int64, @@ -531,6 +645,41 @@ func buildOpenAIImagesStreamErrorBody(message string) []byte { return body } +func buildOpenAIImagesStreamErrorBodyFromUpstream(err *OpenAIImagesUpstreamError) []byte { + if err == nil { + return buildOpenAIImagesStreamErrorBody("") + } + body := buildOpenAIImagesStreamErrorBody(err.clientMessage()) + body, _ = sjson.SetBytes(body, "error.type", err.clientErrorType()) + if code := strings.TrimSpace(err.Code); code != "" { + body, _ = sjson.SetBytes(body, "error.code", code) + } + if param := strings.TrimSpace(err.Param); param != "" { + body, _ = sjson.SetBytes(body, "error.param", param) + } + return body +} + +func writeOpenAIImagesUpstreamErrorResponse(c *gin.Context, err *OpenAIImagesUpstreamError) bool { + if c == nil || c.Writer == nil || c.Writer.Written() || err == nil { + return false + } + errorObj := gin.H{ + "type": err.clientErrorType(), + "message": err.clientMessage(), + } + if code := strings.TrimSpace(err.Code); code != "" { + errorObj["code"] = code + } + if param := strings.TrimSpace(err.Param); param != "" { + errorObj["param"] = param + } + c.JSON(err.clientStatusCode(), gin.H{ + "error": errorObj, + }) + return true +} + func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error { if strings.TrimSpace(eventName) != "" { if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { @@ -588,6 +737,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( return OpenAIUsage{}, 0, nil, err } if len(results) == 0 { + if upstreamErr := extractOpenAIImagesUpstreamError(body); upstreamErr != nil { + setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "") + writeOpenAIImagesUpstreamErrorResponse(c, upstreamErr) + return OpenAIUsage{}, 0, nil, upstreamErr + } return OpenAIUsage{}, 0, nil, fmt.Errorf("upstream did not return image output") } if strings.TrimSpace(firstMeta.Model) == "" { @@ -742,6 +896,16 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( imageCount = len(emitted) imageOutputSizes = openAIResponsesImageResultSizes(finalResults) processDataDone = true + case "error", "response.failed": + if upstreamErr := openAIImagesUpstreamErrorFromSSEPayload(dataBytes); upstreamErr != nil { + if !clientDisconnected { + s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBodyFromUpstream(upstreamErr)) + } + setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "") + processDataErr = upstreamErr + processDataDone = true + return + } } } diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index d47c52ca..52903a1b 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -553,6 +553,59 @@ func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *te require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String()) } +func TestOpenAIGatewayServiceForwardImages_OAuthNonStreamModerationBlockedReturnsClientError(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw blocked image","response_format":"b64_json"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 42}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + svc.httpUpstream = &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_blocked"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000020}}\n\n" + + "data: {\"type\":\"error\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}\n\n" + + "data: {\"type\":\"response.failed\",\"response\":{\"id\":\"resp_blocked\",\"status\":\"failed\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}}\n\n", + )), + }, + } + + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.Nil(t, result) + var upstreamErr *OpenAIImagesUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusBadRequest, upstreamErr.StatusCode) + require.Equal(t, "moderation_blocked", upstreamErr.Code) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Equal(t, "image_generation_user_error", gjson.Get(rec.Body.String(), "error.type").String()) + require.Equal(t, "moderation_blocked", gjson.Get(rec.Body.String(), "error.code").String()) + require.Contains(t, gjson.Get(rec.Body.String(), "error.message").String(), "safety system") +} + func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`) diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 5b55d200..94791189 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -80,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + runtimeBlocker AccountRuntimeBlocker metrics *openAITokenRuntimeMetricsStore refreshAPI *OAuthRefreshAPI executor OAuthRefreshExecutor @@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) { p.refreshPolicy = policy } +func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + p.runtimeBlocker = blocker +} + func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { if p == nil { return OpenAITokenRuntimeMetrics{} @@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account if p == nil || p.accountRepo == nil || account == nil { return } + if p.runtimeBlocker != nil { + p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token") + } bgCtx := context.Background() if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil { slog.Warn("openai_token_provider.set_error_failed", diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index df2f0f3e..fb506523 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) cache.getErr = errors.New("simulated cache miss") provider := NewOpenAITokenProvider(repo, cache, nil) + blocker := &runtimeBlockRecorder{} + provider.SetAccountRuntimeBlocker(blocker) token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) @@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once") require.Contains(t, repo.lastErrorMsg, "refresh_token is missing") + require.Len(t, blocker.accounts, 1) + require.Equal(t, account.ID, blocker.accounts[0].ID) + require.Equal(t, "missing_refresh_token", blocker.reasons[0]) } diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 920a2239..700dbedf 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { return } - s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) + s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) } func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index 1964cdf6..658a1806 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -115,6 +115,10 @@ func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) e panic("unexpected call") } +func (r *paymentOrderLifecycleRedeemRepo) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) { + panic("unexpected call") +} + func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error { panic("unexpected call") } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 892d9aca..c3b160e7 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -28,10 +28,16 @@ type RateLimitService struct { openAI403CounterCache OpenAI403CounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator + runtimeBlocker AccountRuntimeBlocker usageCacheMu sync.RWMutex usageCache map[int64]*geminiUsageCacheEntry } +type AccountRuntimeBlocker interface { + BlockAccountScheduling(account *Account, until time.Time, reason string) + ClearAccountSchedulingBlock(accountID int64) +} + // SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。 type SuccessfulTestRecoveryResult struct { ClearedError bool @@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + s.runtimeBlocker = blocker +} + +func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) { + if s == nil || s.runtimeBlocker == nil || account == nil { + return + } + s.runtimeBlocker.BlockAccountScheduling(account, until, reason) +} + +func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) { + if s == nil || s.runtimeBlocker == nil || accountID <= 0 { + return + } + s.runtimeBlocker.ClearAccountSchedulingBlock(accountID) +} + // ErrorPolicyResult 表示错误策略检查的结果 type ErrorPolicyResult int @@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc cooldownMinutes = 10 } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + s.notifyAccountSchedulingBlocked(account, until, "oauth_401") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil { slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) } @@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { + s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error") if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err) return @@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute) reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg) + s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) s.handleAuthError(ctx, account, msg) @@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac // handleCustomErrorCode 处理自定义错误码,停止账号调度 func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) { msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg + s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code") if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil { slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err) return @@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody) s.persistOpenAICodexSnapshot(ctx, account, headers) if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil { + s.notifyAccountSchedulingBlocked(account, *resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口 if result := calculateAnthropic429ResetTime(headers); result != nil { + s.notifyAccountSchedulingBlocked(account, result.resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 尝试解析 OpenAI 的 usage_limit_reached 错误 if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil { resetTime := time.Unix(*resetAt, 0) + s.notifyAccountSchedulingBlocked(account, resetTime, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head // 尝试解析 Gemini 格式(用于其他平台) if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil { resetTime := time.Unix(*resetAt, 0) + s.notifyAccountSchedulingBlocked(account, resetTime, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head resetAt := time.Unix(ts, 0) // 标记限流状态 + s.notifyAccountSchedulingBlocked(account, resetAt, "429") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) return @@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun resetAt := time.Now().Add(cooldown) slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String()) + s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback") if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err) } @@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) { } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) + s.notifyAccountSchedulingBlocked(account, until, "529") if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { slog.Warn("overload_set_failed", "account_id", account.ID, "error", err) return @@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) } } s.ResetOpenAI403Counter(ctx, accountID) + s.notifyAccountSchedulingBlockCleared(accountID) return nil } @@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in } if result.ClearedError || result.ClearedRateLimit { s.ResetOpenAI403Counter(ctx, accountID) + if result.ClearedError && !result.ClearedRateLimit { + s.notifyAccountSchedulingBlockCleared(accountID) + } } return result, nil @@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err) } + s.notifyAccountSchedulingBlockCleared(accountID) return nil } @@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account reason = strings.TrimSpace(state.ErrorMessage) } + s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err) return false @@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, reason = state.ErrorMessage } + s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable") if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err) return false @@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool { errorMsg := "Stream data interval timeout (repeated failures) for model: " + model + s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error") if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err) return false diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go index 2fd11b71..9d4b2714 100644 --- a/backend/internal/service/ratelimit_service_403_test.go +++ b/backend/internal/service/ratelimit_service_403_test.go @@ -6,16 +6,36 @@ import ( "context" "net/http" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) +type runtimeBlockRecorder struct { + accounts []*Account + until []time.Time + reasons []string + clearedIDs []int64 +} + +func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) { + r.accounts = append(r.accounts, account) + r.until = append(r.until, until) + r.reasons = append(r.reasons, reason) +} + +func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) { + r.clearedIDs = append(r.clearedIDs, accountID) +} + func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) { repo := &rateLimitAccountRepoStub{} counter := &openAI403CounterCacheStub{counts: []int64{1}} + blocker := &runtimeBlockRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetOpenAI403CounterCache(counter) + service.SetAccountRuntimeBlocker(blocker) account := &Account{ ID: 301, Platform: PlatformOpenAI, @@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable require.Equal(t, 1, repo.tempCalls) require.Contains(t, repo.lastTempReason, "temporary edge rejection") require.Contains(t, repo.lastTempReason, "(1/3)") + require.Len(t, blocker.accounts, 1) + require.Equal(t, account.ID, blocker.accounts[0].ID) + require.Equal(t, "openai_403_temp", blocker.reasons[0]) + require.True(t, blocker.until[0].After(time.Now())) } func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) { diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go index 1d7a02fc..e8d312f9 100644 --- a/backend/internal/service/ratelimit_service_clear_test.go +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi }, } cache := &tempUnschedCacheRecorder{} + blocker := &runtimeBlockRecorder{} svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + svc.SetAccountRuntimeBlocker(blocker) result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42) require.NoError(t, err) @@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi require.Equal(t, 1, repo.clearModelRateLimitCalls) require.Equal(t, 1, repo.clearTempUnschedCalls) require.Equal(t, []int64{42}, cache.deletedIDs) + require.Equal(t, []int64{42}, blocker.clearedIDs) } func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) { diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 73aa02b1..8db0d7da 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -54,6 +54,7 @@ type RedeemCodeRepository interface { GetByID(ctx context.Context, id int64) (*RedeemCode, error) GetByCode(ctx context.Context, code string) (*RedeemCode, error) Update(ctx context.Context, code *RedeemCode) error + BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error) Delete(ctx context.Context, id int64) error Use(ctx context.Context, id, userID int64) error @@ -82,6 +83,54 @@ type RedeemCodeResponse struct { CreatedAt time.Time `json:"created_at"` } +type NullableTimeUpdate struct { + Set bool + Value *time.Time +} + +type NullableInt64Update struct { + Set bool + Value *int64 +} + +type RedeemCodeBatchUpdateFields struct { + Status *string + ExpiresAt NullableTimeUpdate + Notes *string + GroupID NullableInt64Update + + // Core fields are intentionally modeled only so service validation can + // reject payloads that try to mutate redemption value semantics in bulk. + Type *string + Value *float64 +} + +func (f RedeemCodeBatchUpdateFields) HasChanges() bool { + return f.Status != nil || + f.ExpiresAt.Set || + f.Notes != nil || + f.GroupID.Set || + f.Type != nil || + f.Value != nil +} + +func (f RedeemCodeBatchUpdateFields) HasCoreFieldChanges() bool { + return f.Type != nil || f.Value != nil +} + +func (f RedeemCodeBatchUpdateFields) TouchesUsedSensitiveFields() bool { + return f.Status != nil || f.ExpiresAt.Set || f.GroupID.Set +} + +type RedeemCodeBatchUpdateInput struct { + IDs []int64 + Fields RedeemCodeBatchUpdateFields +} + +type RedeemCodeBatchUpdateResult struct { + Updated int64 `json:"updated"` +} + // RedeemService 兑换码服务 type RedeemService struct { redeemRepo RedeemCodeRepository @@ -218,6 +267,61 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error return nil } +func (s *RedeemService) BatchUpdate(ctx context.Context, input *RedeemCodeBatchUpdateInput) (*RedeemCodeBatchUpdateResult, error) { + if input == nil { + return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID", "batch update input is required") + } + if len(input.IDs) == 0 { + return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required") + } + if !input.Fields.HasChanges() { + return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_EMPTY", "at least one field must be selected") + } + if input.Fields.HasCoreFieldChanges() { + return nil, infraerrors.BadRequest("REDEEM_CODE_CORE_FIELDS_IMMUTABLE", "type and value cannot be batch updated") + } + + ids := make([]int64, 0, len(input.IDs)) + seen := make(map[int64]struct{}, len(input.IDs)) + for _, id := range input.IDs { + if id <= 0 { + return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID_ID", "ids must be positive") + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + if len(ids) == 0 { + return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required") + } + + if input.Fields.Status != nil { + switch *input.Fields.Status { + case StatusUnused, StatusDisabled: + default: + return nil, infraerrors.BadRequest("REDEEM_CODE_STATUS_INVALID", "status must be unused or disabled") + } + } + if input.Fields.ExpiresAt.Set && input.Fields.ExpiresAt.Value != nil { + expiresAt := input.Fields.ExpiresAt.Value.UTC() + if !expiresAt.After(time.Now().UTC()) { + return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future") + } + input.Fields.ExpiresAt.Value = &expiresAt + } + if input.Fields.GroupID.Set && input.Fields.GroupID.Value != nil && *input.Fields.GroupID.Value <= 0 { + return nil, infraerrors.BadRequest("REDEEM_CODE_GROUP_ID_INVALID", "group_id must be positive") + } + + updated, err := s.redeemRepo.BatchUpdate(ctx, ids, input.Fields) + if err != nil { + return nil, err + } + return &RedeemCodeBatchUpdateResult{Updated: updated}, nil +} + // checkRedeemRateLimit 检查用户兑换错误次数是否超限 func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error { if s.cache == nil { diff --git a/backend/internal/service/redeem_service_batch_update_test.go b/backend/internal/service/redeem_service_batch_update_test.go new file mode 100644 index 00000000..e54019cc --- /dev/null +++ b/backend/internal/service/redeem_service_batch_update_test.go @@ -0,0 +1,75 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestRedeemService_BatchUpdate_PartialFields(t *testing.T) { + status := StatusDisabled + notes := "maintenance window" + expiresAt := time.Now().UTC().Add(24 * time.Hour) + repo := &redeemRepoStub{} + svc := &RedeemService{redeemRepo: repo} + + result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{ + IDs: []int64{1, 2, 2}, + Fields: RedeemCodeBatchUpdateFields{ + Status: &status, + ExpiresAt: NullableTimeUpdate{Set: true, Value: &expiresAt}, + Notes: ¬es, + }, + }) + + require.NoError(t, err) + require.Equal(t, int64(2), result.Updated) + require.True(t, repo.batchUpdateCalled) + require.Equal(t, []int64{1, 2}, repo.batchUpdateIDs) + require.Equal(t, &status, repo.batchUpdateFields.Status) + require.True(t, repo.batchUpdateFields.ExpiresAt.Set) + require.WithinDuration(t, expiresAt, *repo.batchUpdateFields.ExpiresAt.Value, time.Second) + require.Equal(t, ¬es, repo.batchUpdateFields.Notes) + require.False(t, repo.batchUpdateFields.GroupID.Set) + require.Nil(t, repo.batchUpdateFields.Type) + require.Nil(t, repo.batchUpdateFields.Value) +} + +func TestRedeemService_BatchUpdate_RejectsInvalidID(t *testing.T) { + repo := &redeemRepoStub{} + svc := &RedeemService{redeemRepo: repo} + notes := "bad id" + + result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{ + IDs: []int64{1, 0}, + Fields: RedeemCodeBatchUpdateFields{Notes: ¬es}, + }) + + require.Nil(t, result) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + require.False(t, repo.batchUpdateCalled) +} + +func TestRedeemService_BatchUpdate_RejectsCoreFieldsForUsedCodes(t *testing.T) { + repo := &redeemRepoStub{} + svc := &RedeemService{redeemRepo: repo} + newValue := 100.0 + + result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{ + IDs: []int64{42}, + Fields: RedeemCodeBatchUpdateFields{ + Value: &newValue, + }, + }) + + require.Nil(t, result) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + require.False(t, repo.batchUpdateCalled) +} diff --git a/backend/internal/service/registration_email_policy.go b/backend/internal/service/registration_email_policy.go index 875668c7..0c910d62 100644 --- a/backend/internal/service/registration_email_policy.go +++ b/backend/internal/service/registration_email_policy.go @@ -26,12 +26,17 @@ func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool { if len(whitelist) == 0 { return true } - suffix := RegistrationEmailSuffix(email) - if suffix == "" { + _, domain, ok := splitEmailForPolicy(email) + if !ok { return false } + suffix := "@" + domain for _, allowed := range whitelist { - if suffix == allowed { + allowed = strings.ToLower(strings.TrimSpace(allowed)) + if strings.HasPrefix(allowed, "@") && suffix == allowed { + return true + } + if strings.HasPrefix(allowed, "*.") && registrationEmailDomainMatchesWildcard(domain, allowed) { return true } } @@ -98,6 +103,14 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) { return "", nil } + if strings.HasPrefix(value, "*.") { + domain := strings.TrimPrefix(value, "*.") + if !isValidRegistrationEmailDomain(domain) { + return "", fmt.Errorf("invalid email suffix: %q", raw) + } + return "*." + domain, nil + } + domain := value if strings.Contains(value, "@") { if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 { @@ -106,13 +119,27 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) { domain = strings.TrimPrefix(value, "@") } - if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) { + if !isValidRegistrationEmailDomain(domain) { return "", fmt.Errorf("invalid email suffix: %q", raw) } return "@" + domain, nil } +func isValidRegistrationEmailDomain(domain string) bool { + return domain != "" && + !strings.Contains(domain, "@") && + registrationEmailDomainPattern.MatchString(domain) +} + +func registrationEmailDomainMatchesWildcard(domain string, allowed string) bool { + base := strings.TrimPrefix(allowed, "*.") + if !isValidRegistrationEmailDomain(base) { + return false + } + return domain == base || strings.HasSuffix(domain, "."+base) +} + func splitEmailForPolicy(raw string) (local string, domain string, ok bool) { email := strings.ToLower(strings.TrimSpace(raw)) local, domain, found := strings.Cut(email, "@") diff --git a/backend/internal/service/registration_email_policy_test.go b/backend/internal/service/registration_email_policy_test.go index f0c46642..79f47be0 100644 --- a/backend/internal/service/registration_email_policy_test.go +++ b/backend/internal/service/registration_email_policy_test.go @@ -9,23 +9,36 @@ import ( ) func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) { - got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "}) + got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar ", "*.EDU.CN"}) require.NoError(t, err) - require.Equal(t, []string{"@example.com", "@foo.bar"}, got) + require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got) } func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { - _, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"}) - require.Error(t, err) + for _, item := range []string{"@invalid_domain", "*.", "*", "*.@", "*.foo"} { + t.Run(item, func(t *testing.T) { + _, err := NormalizeRegistrationEmailSuffixWhitelist([]string{item}) + require.Error(t, err) + }) + } } func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) { - got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`) - require.Equal(t, []string{"@example.com", "@foo.bar"}, got) + got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","*.EDU.CN","@invalid_domain","*.foo"]`) + require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got) } func TestIsRegistrationEmailSuffixAllowed(t *testing.T) { require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"})) require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"})) + require.True(t, IsRegistrationEmailSuffixAllowed("user@qq.com", []string{"@qq.com"})) + require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.qq.com", []string{"@qq.com"})) + require.True(t, IsRegistrationEmailSuffixAllowed("student@cs.edu.cn", []string{"*.edu.cn"})) + require.True(t, IsRegistrationEmailSuffixAllowed("student@edu.cn", []string{"*.edu.cn"})) + require.False(t, IsRegistrationEmailSuffixAllowed("student@foo.cn", []string{"*.edu.cn"})) + require.True(t, IsRegistrationEmailSuffixAllowed("user@a.com", []string{"@a.com", "*.b.cn"})) + require.True(t, IsRegistrationEmailSuffixAllowed("user@school.b.cn", []string{"@a.com", "*.b.cn"})) + require.True(t, IsRegistrationEmailSuffixAllowed("user@b.cn", []string{"@a.com", "*.b.cn"})) + require.False(t, IsRegistrationEmailSuffixAllowed("user@c.cn", []string{"@a.com", "*.b.cn"})) require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{})) } diff --git a/backend/internal/service/scheduler_snapshot_hydration_test.go b/backend/internal/service/scheduler_snapshot_hydration_test.go index 0b32c2ad..778cab23 100644 --- a/backend/internal/service/scheduler_snapshot_hydration_test.go +++ b/backend/internal/service/scheduler_snapshot_hydration_test.go @@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul } } +func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) { + cache := &snapshotHydrationCache{ + accounts: map[int64]*Account{}, + } + schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil) + svc := &OpenAIGatewayService{ + schedulerSnapshot: schedulerSnapshot, + } + releaseCalls := 0 + + selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() { + releaseCalls++ + }) + + if err == nil { + t.Fatalf("expected hydration error") + } + if selection != nil { + t.Fatalf("expected nil selection on hydration error") + } + if releaseCalls != 1 { + t.Fatalf("expected release to be called once, got %d", releaseCalls) + } +} + func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) { cache := &snapshotHydrationCache{ snapshot: []*Account{ diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index bd99e341..5eef2c13 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -597,6 +597,23 @@ func (s *SettingService) SetProxyRepository(repo ProxyRepository) { s.proxyRepo = repo } +func (s *SettingService) LoadAPIKeyACLTrustForwardedIPSetting(ctx context.Context) error { + if s == nil || s.cfg == nil || s.settingRepo == nil { + return nil + } + value, err := s.settingRepo.GetValue(ctx, SettingKeyAPIKeyACLTrustForwardedIP) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + s.cfg.SetTrustForwardedIPForAPIKeyACL(s.cfg.Security.TrustForwardedIPForAPIKeyACL) + return nil + } + return fmt.Errorf("get api key acl forwarded ip setting: %w", err) + } + enabled := value == "true" + s.cfg.SetTrustForwardedIPForAPIKeyACL(enabled) + return nil +} + // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -633,6 +650,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyLoginAgreementDocuments, SettingKeyTurnstileEnabled, SettingKeyTurnstileSiteKey, + SettingKeyAPIKeyACLTrustForwardedIP, SettingKeySiteName, SettingKeySiteLogo, SettingKeySiteSubtitle, @@ -1568,6 +1586,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting if settings.TurnstileSecretKey != "" { updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey } + updates[SettingKeyAPIKeyACLTrustForwardedIP] = strconv.FormatBool(settings.APIKeyACLTrustForwardedIP) // LinuxDo Connect OAuth 登录 updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) @@ -1777,10 +1796,11 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) - // Balance low notification + // 余额、订阅到期与账号限额通知 updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL + updates[SettingKeySubscriptionExpiryNotifyEnabled] = strconv.FormatBool(settings.SubscriptionExpiryNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) @@ -1867,6 +1887,9 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { enabled: settings.OpenAIAdvancedSchedulerEnabled, expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), }) + if s.cfg != nil { + s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP) + } if s.onUpdate != nil { s.onUpdate() // Invalidate cache after settings update } @@ -2463,6 +2486,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyLoginAgreementMode: defaultLoginAgreementMode, SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate, SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON, + SettingKeyAPIKeyACLTrustForwardedIP: "false", SettingKeySiteName: "Sub2API", SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", @@ -2622,6 +2646,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin if loginAgreementUpdatedAt == "" { loginAgreementUpdatedAt = defaultLoginAgreementDate } + apiKeyACLTrustForwardedIP := false + if value, ok := settings[SettingKeyAPIKeyACLTrustForwardedIP]; ok { + apiKeyACLTrustForwardedIP = value == "true" + } else if s != nil && s.cfg != nil { + apiKeyACLTrustForwardedIP = s.cfg.Security.TrustForwardedIPForAPIKeyACL + } result := &SystemSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: emailVerifyEnabled, @@ -2644,6 +2674,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", + APIKeyACLTrustForwardedIP: apiKeyACLTrustForwardedIP, SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), SiteLogo: settings[SettingKeySiteLogo], SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), @@ -3131,14 +3162,15 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" - // Balance low notification + // 余额、订阅到期与账号限额通知 result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { result.BalanceLowNotifyThreshold = v } result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL] + result.SubscriptionExpiryNotifyEnabled = !isFalseSettingValue(settings[SettingKeySubscriptionExpiryNotifyEnabled]) - // Account quota notification + // 账号限额通知 result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true" if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" { result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw) diff --git a/backend/internal/service/setting_service_public_test.go b/backend/internal/service/setting_service_public_test.go index 1ecd4e6f..2faa4d82 100644 --- a/backend/internal/service/setting_service_public_test.go +++ b/backend/internal/service/setting_service_public_test.go @@ -53,14 +53,14 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis values: map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", - SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`, + SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","*.EDU.CN","@invalid_domain",""]`, }, } svc := NewSettingService(repo, &config.Config{}) settings, err := svc.GetPublicSettings(context.Background()) require.NoError(t, err) - require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist) + require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, settings.RegistrationEmailSuffixWhitelist) } func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) { diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go index d6b6b6cd..379bf9bc 100644 --- a/backend/internal/service/setting_service_update_test.go +++ b/backend/internal/service/setting_service_update_test.go @@ -213,10 +213,10 @@ func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normaliz svc := NewSettingService(repo, &config.Config{}) err := svc.UpdateSettings(context.Background(), &SystemSettings{ - RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "}, + RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar ", "*.EDU.CN"}, }) require.NoError(t, err) - require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist]) + require.Equal(t, `["@example.com","@foo.bar","*.edu.cn"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist]) } func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) { @@ -290,6 +290,30 @@ func TestSettingService_UpdateSettings_AntigravityUserAgentVersion(t *testing.T) require.Equal(t, "1.23.2", repo.updates[SettingKeyAntigravityUserAgentVersion]) } +func TestSettingService_UpdateSettings_APIKeyACLTrustForwardedIPRefreshesConfig(t *testing.T) { + repo := &settingUpdateRepoStub{} + cfg := &config.Config{} + svc := NewSettingService(repo, cfg) + + err := svc.UpdateSettings(context.Background(), &SystemSettings{ + APIKeyACLTrustForwardedIP: true, + }) + require.NoError(t, err) + require.Equal(t, "true", repo.updates[SettingKeyAPIKeyACLTrustForwardedIP]) + require.True(t, cfg.Security.TrustForwardedIPForAPIKeyACL) + require.True(t, cfg.TrustForwardedIPForAPIKeyACL()) +} + +func TestSettingService_ParseSettings_APIKeyACLTrustForwardedIPFallsBackToConfigWhenMissing(t *testing.T) { + cfg := &config.Config{} + cfg.Security.TrustForwardedIPForAPIKeyACL = true + svc := NewSettingService(&settingUpdateRepoStub{}, cfg) + + got := svc.parseSettings(map[string]string{}) + + require.True(t, got.APIKeyACLTrustForwardedIP) +} + func TestSettingService_GetAntigravityUserAgentVersion_Precedence(t *testing.T) { t.Run("后台设置优先", func(t *testing.T) { svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{ diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 1e5e8b1c..c9bea224 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -38,6 +38,7 @@ type SystemSettings struct { TurnstileSiteKey string TurnstileSecretKey string TurnstileSecretKeyConfigured bool + APIKeyACLTrustForwardedIP bool // LinuxDo Connect OAuth 登录 LinuxDoConnectEnabled bool @@ -204,15 +205,18 @@ type SystemSettings struct { PaymentVisibleMethodAlipayEnabled bool PaymentVisibleMethodWxpayEnabled bool - // OpenAI account scheduling + // OpenAI 账号调度 OpenAIAdvancedSchedulerEnabled bool - // Balance low notification + // 余额不足提醒 BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64 BalanceLowNotifyRechargeURL string - // Account quota notification + // 订阅到期提醒 + SubscriptionExpiryNotifyEnabled bool + + // 账号限额通知 AccountQuotaNotifyEnabled bool AccountQuotaNotifyEmails []NotifyEmailEntry } diff --git a/backend/internal/service/subscription_expiry_service.go b/backend/internal/service/subscription_expiry_service.go index 9b3a0309..a9ec9042 100644 --- a/backend/internal/service/subscription_expiry_service.go +++ b/backend/internal/service/subscription_expiry_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "log" "strconv" @@ -14,6 +15,7 @@ import ( // SubscriptionExpiryService periodically updates expired subscription status. type SubscriptionExpiryService struct { userSubRepo UserSubscriptionRepository + settingRepo SettingRepository notificationEmailService *NotificationEmailService interval time.Duration stopCh chan struct{} @@ -29,6 +31,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv } } +func (s *SubscriptionExpiryService) SetSettingRepository(settingRepo SettingRepository) { + s.settingRepo = settingRepo +} + func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) { s.notificationEmailService = notificationEmailService } @@ -84,6 +90,9 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) { if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil { return } + if !s.expiryReminderEnabled(ctx) { + return + } for page := 1; ; page++ { subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc") if err != nil { @@ -99,6 +108,21 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) { } } +func (s *SubscriptionExpiryService) expiryReminderEnabled(ctx context.Context) bool { + if s == nil || s.settingRepo == nil { + return true + } + value, err := s.settingRepo.GetValue(ctx, SettingKeySubscriptionExpiryNotifyEnabled) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return true + } + log.Printf("[SubscriptionExpiry] Read expiry reminder switch failed: %v", err) + return false + } + return !isFalseSettingValue(value) +} + func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) { if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" { return diff --git a/backend/internal/service/subscription_expiry_service_test.go b/backend/internal/service/subscription_expiry_service_test.go new file mode 100644 index 00000000..00ae3eb3 --- /dev/null +++ b/backend/internal/service/subscription_expiry_service_test.go @@ -0,0 +1,164 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type subscriptionExpiryRepoStub struct { + listCalls int +} + +func (r *subscriptionExpiryRepoStub) Create(context.Context, *UserSubscription) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) GetByID(context.Context, int64) (*UserSubscription, error) { + return nil, ErrSubscriptionNotFound +} + +func (r *subscriptionExpiryRepoStub) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) { + return nil, ErrSubscriptionNotFound +} + +func (r *subscriptionExpiryRepoStub) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) { + return nil, ErrSubscriptionNotFound +} + +func (r *subscriptionExpiryRepoStub) Update(context.Context, *UserSubscription) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) Delete(context.Context, int64) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) ListByUserID(context.Context, int64) ([]UserSubscription, error) { + return nil, nil +} + +func (r *subscriptionExpiryRepoStub) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) { + return nil, nil +} + +func (r *subscriptionExpiryRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, nil +} + +func (r *subscriptionExpiryRepoStub) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) { + r.listCalls++ + return nil, &pagination.PaginationResult{Page: 1, Pages: 1}, nil +} + +func (r *subscriptionExpiryRepoStub) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) { + return false, nil +} + +func (r *subscriptionExpiryRepoStub) ExtendExpiry(context.Context, int64, time.Time) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) UpdateStatus(context.Context, int64, string) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) UpdateNotes(context.Context, int64, string) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) ActivateWindows(context.Context, int64, time.Time) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) ResetDailyUsage(context.Context, int64, time.Time) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) ResetWeeklyUsage(context.Context, int64, time.Time) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) ResetMonthlyUsage(context.Context, int64, time.Time) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) IncrementUsage(context.Context, int64, float64) error { + return nil +} + +func (r *subscriptionExpiryRepoStub) BatchUpdateExpiredStatus(context.Context) (int64, error) { + return 0, nil +} + +type subscriptionExpirySettingRepoStub struct { + values map[string]string + err error +} + +func (r *subscriptionExpirySettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, ErrSettingNotFound +} + +func (r *subscriptionExpirySettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if r.err != nil { + return "", r.err + } + value, ok := r.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (r *subscriptionExpirySettingRepoStub) Set(context.Context, string, string) error { + return nil +} + +func (r *subscriptionExpirySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) { + return nil, nil +} + +func (r *subscriptionExpirySettingRepoStub) SetMultiple(context.Context, map[string]string) error { + return nil +} + +func (r *subscriptionExpirySettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return nil, nil +} + +func (r *subscriptionExpirySettingRepoStub) Delete(context.Context, string) error { + return nil +} + +func TestSubscriptionExpiryService_ExpiryReminderEnabledDefaultsToTrue(t *testing.T) { + svc := NewSubscriptionExpiryService(nil, time.Minute) + svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{values: map[string]string{}}) + + require.True(t, svc.expiryReminderEnabled(context.Background())) +} + +func TestSubscriptionExpiryService_ExpiryReminderDisabledSkipsSubscriptionScan(t *testing.T) { + repo := &subscriptionExpiryRepoStub{} + settingRepo := &subscriptionExpirySettingRepoStub{ + values: map[string]string{SettingKeySubscriptionExpiryNotifyEnabled: "false"}, + } + svc := NewSubscriptionExpiryService(repo, time.Minute) + svc.SetSettingRepository(settingRepo) + svc.SetNotificationEmailService(NewNotificationEmailService(settingRepo, nil)) + + svc.sendExpiryReminders(context.Background()) + + require.Zero(t, repo.listCalls) +} + +func TestSubscriptionExpiryService_ExpiryReminderSettingReadErrorFailsClosed(t *testing.T) { + svc := NewSubscriptionExpiryService(nil, time.Minute) + svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{err: errors.New("db down")}) + + require.False(t, svc.expiryReminderEnabled(context.Background())) +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 22f4aa29..c5f1991b 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -27,6 +27,7 @@ type TokenRefreshService struct { schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存 refreshAPI *OAuthRefreshAPI // 统一刷新 API + runtimeBlocker AccountRuntimeBlocker // OpenAI privacy: 刷新成功后检查并设置 training opt-out privacyClientFactory PrivacyClientFactory @@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) { s.refreshPolicy = policy } +func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) { + s.runtimeBlocker = blocker +} + +func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) { + if s == nil || s.runtimeBlocker == nil || account == nil { + return + } + s.runtimeBlocker.BlockAccountScheduling(account, until, reason) +} + +func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) { + if s == nil || s.runtimeBlocker == nil || accountID <= 0 { + return + } + s.runtimeBlocker.ClearAccountSchedulingBlock(accountID) +} + // Start 启动后台刷新服务 func (s *TokenRefreshService) Start() { if !s.cfg.Enabled { @@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc // 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回 if isNonRetryableRefreshError(err) { errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err) + s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable") if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil { slog.Error("token_refresh.set_error_status_failed", "account_id", account.ID, @@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc // 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试) until := time.Now().Add(tokenRefreshTempUnschedDuration) reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr) + s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted") if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil { slog.Warn("token_refresh.set_temp_unschedulable_failed", "account_id", account.ID, @@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A ) } else { slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID) + s.notifyAccountSchedulingBlockCleared(account.ID) } } // 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景) @@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A ) } else { slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID) + s.notifyAccountSchedulingBlockCleared(account.ID) } // 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态 if s.tempUnschedCache != nil { @@ -417,11 +440,12 @@ func isNonRetryableRefreshError(err error) bool { } msg := strings.ToLower(err.Error()) nonRetryable := []string{ - "invalid_grant", // refresh_token 已失效 - "invalid_client", // 客户端配置错误 - "unauthorized_client", // 客户端未授权 - "access_denied", // 访问被拒绝 - "missing_project_id", // 缺少 project_id + "invalid_grant", // refresh_token 已失效 + "refresh_token_reused", // OpenAI refresh_token 已被使用,必须重新授权 + "invalid_client", // 客户端配置错误 + "unauthorized_client", // 客户端未授权 + "access_denied", // 访问被拒绝 + "missing_project_id", // 缺少 project_id "no refresh token available", } for _, needle := range nonRetryable { diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 2179a85e..d80b475e 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -532,6 +532,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) { {name: "network_error", err: errors.New("network timeout"), expected: false}, {name: "invalid_grant", err: errors.New("invalid_grant"), expected: true}, {name: "invalid_client", err: errors.New("invalid_client"), expected: true}, + {name: "refresh_token_reused", err: errors.New(`OPENAI_OAUTH_TOKEN_REFRESH_FAILED: token refresh failed: status 401, body: {"error":{"code":"refresh_token_reused"}}`), expected: true}, {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true}, {name: "access_denied", err: errors.New("access_denied"), expected: true}, {name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true}, diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index b2309e31..94eb5d20 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -62,6 +62,7 @@ func ProvideTokenRefreshService( privacyClientFactory PrivacyClientFactory, proxyRepo ProxyRepository, refreshAPI *OAuthRefreshAPI, + runtimeBlocker AccountRuntimeBlocker, ) *TokenRefreshService { svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache) // 注入 OpenAI privacy opt-out 依赖 @@ -70,6 +71,7 @@ func ProvideTokenRefreshService( svc.SetRefreshAPI(refreshAPI) // 调用侧显式注入后台刷新策略,避免策略漂移 svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy()) + svc.SetAccountRuntimeBlocker(runtimeBlocker) svc.Start() return svc } @@ -154,8 +156,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe } // ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService. -func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService { +func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, settingRepo SettingRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService { svc := NewSubscriptionExpiryService(userSubRepo, time.Minute) + svc.SetSettingRepository(settingRepo) svc.SetNotificationEmailService(notificationEmailService) svc.Start() return svc @@ -185,6 +188,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err) } if cfg != nil { + svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond) svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval) } return svc @@ -400,6 +404,9 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit svc := NewSettingService(settingRepo, cfg) svc.SetDefaultSubscriptionGroupReader(groupRepo) svc.SetProxyRepository(proxyRepo) + if err := svc.LoadAPIKeyACLTrustForwardedIPSetting(context.Background()); err != nil { + logger.LegacyPrintf("service.setting", "Warning: load api key acl forwarded ip setting failed: %v", err) + } antigravity.SetUserAgentVersionResolver(svc.GetAntigravityUserAgentVersion) return svc } @@ -455,6 +462,7 @@ var ProviderSet = wire.NewSet( NewRPMTokenBucketService, NewGatewayService, NewOpenAIGatewayService, + wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)), NewOAuthService, NewOpenAIOAuthService, NewGeminiOAuthService, diff --git a/backend/migrations/140_extend_user_provider_default_grants_check.sql b/backend/migrations/140_extend_user_provider_default_grants_check.sql new file mode 100644 index 00000000..c739e6e0 --- /dev/null +++ b/backend/migrations/140_extend_user_provider_default_grants_check.sql @@ -0,0 +1,12 @@ +-- 修复:user_provider_default_grants 表的 provider_type check 约束 +-- 与 auth_identities / auth_identity_channels / pending_auth_sessions 保持一致, +-- 否则启用了 auth_source_default_{github,google,dingtalk}_grant_on_first_bind +-- 之后,OAuth 首次绑定流程会因约束违反而失败。 +-- 参见 migrations 135、136 漏改本表。 + +ALTER TABLE user_provider_default_grants + DROP CONSTRAINT IF EXISTS user_provider_default_grants_provider_type_check; + +ALTER TABLE user_provider_default_grants + ADD CONSTRAINT user_provider_default_grants_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk')); diff --git a/backend/migrations/141_subscription_expiry_notify_enabled.sql b/backend/migrations/141_subscription_expiry_notify_enabled.sql new file mode 100644 index 00000000..37043806 --- /dev/null +++ b/backend/migrations/141_subscription_expiry_notify_enabled.sql @@ -0,0 +1,4 @@ +-- 订阅到期提醒邮件开关,默认保持历史行为:开启。 +INSERT INTO settings (key, value, updated_at) +VALUES ('subscription_expiry_notify_enabled', 'true', NOW()) +ON CONFLICT (key) DO NOTHING; diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 0d61b710..8e9b0e3b 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -405,6 +405,9 @@ gateway: # Enable batch load calculation for scheduling # 启用调度批量负载计算 load_batch_enabled: true + # Tiny in-process TTL for batch load reads in milliseconds (0 disables) + # 调度批量负载读取的进程内短缓存 TTL(毫秒,0 表示禁用) + load_batch_cache_ttl_ms: 200 # Slot cleanup interval (duration) # 并发槽位清理周期(时间段) slot_cleanup_interval: 30s diff --git a/frontend/package.json b/frontend/package.json index ec2b5942..7ca264d2 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -57,5 +57,10 @@ "vite-plugin-checker": "^0.9.1", "vitest": "^2.1.9", "vue-tsc": "^2.2.0" + }, + "pnpm": { + "overrides": { + "js-cookie": "3.0.7" + } } } diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index a401b57d..b7b0df0e 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -4,6 +4,9 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +overrides: + js-cookie: 3.0.7 + importers: .: @@ -2882,9 +2885,9 @@ packages: engines: {node: '>=14'} hasBin: true - js-cookie@3.0.5: - resolution: {integrity: sha512-cEiJEAEoIbWfCZYKWhVwFuvPX1gETRYPw6LlaTKoxD3s2AkXzkCjnp6h0V77ozyqj0jakteJ4YqDJT830+lVGw==} - engines: {node: '>=14'} + js-cookie@3.0.7: + resolution: {integrity: sha512-z/wZZgDrkNV1eA0ULjM/F9/50Ya8fbzgKneSpoPsXSGd0KnpdtHfOZWK+GcwLk+EZbS4F9RBhU+K2RgzuDaItw==} + engines: {node: '>=20'} js-tokens@4.0.0: resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} @@ -6365,7 +6368,7 @@ snapshots: '@types/js-cookie': 3.0.6 dayjs: 1.11.20 intersection-observer: 0.12.2 - js-cookie: 3.0.5 + js-cookie: 3.0.7 lodash: 4.18.1 react: 19.2.3 react-dom: 19.2.3(react@19.2.3) @@ -7656,10 +7659,10 @@ snapshots: config-chain: 1.1.13 editorconfig: 1.0.4 glob: 10.5.0 - js-cookie: 3.0.5 + js-cookie: 3.0.7 nopt: 7.2.1 - js-cookie@3.0.5: {} + js-cookie@3.0.7: {} js-tokens@4.0.0: {} diff --git a/frontend/src/api/admin/redeem.ts b/frontend/src/api/admin/redeem.ts index 398d68a4..777bb110 100644 --- a/frontend/src/api/admin/redeem.ts +++ b/frontend/src/api/admin/redeem.ts @@ -7,6 +7,7 @@ import { apiClient } from '../client' import type { RedeemCode, GenerateRedeemCodesRequest, + BatchUpdateRedeemCodeFields, RedeemCodeType, PaginatedResponse } from '@/types' @@ -23,7 +24,7 @@ export async function list( pageSize: number = 20, filters?: { type?: RedeemCodeType - status?: 'active' | 'used' | 'expired' | 'unused' + status?: 'active' | 'used' | 'expired' | 'unused' | 'disabled' search?: string sort_by?: string sort_order?: 'asc' | 'desc' @@ -118,6 +119,26 @@ export async function batchDelete(ids: number[]): Promise<{ return data } +/** + * Batch update selected redeem code fields + * @param ids - Array of redeem code IDs + * @param fields - Field collection to update + * @returns Updated count + */ +export async function batchUpdate( + ids: number[], + fields: BatchUpdateRedeemCodeFields +): Promise<{ + updated: number + message: string +}> { + const { data } = await apiClient.post<{ + updated: number + message: string + }>('/admin/redeem-codes/batch-update', { ids, fields }) + return data +} + /** * Expire redeem code * @param id - Redeem code ID @@ -158,7 +179,7 @@ export async function getStats(): Promise<{ */ export async function exportCodes(filters?: { type?: RedeemCodeType - status?: 'used' | 'expired' | 'unused' + status?: 'used' | 'expired' | 'unused' | 'disabled' search?: string sort_by?: string sort_order?: 'asc' | 'desc' @@ -176,6 +197,7 @@ export const redeemAPI = { generate, delete: deleteCode, batchDelete, + batchUpdate, expire, getStats, exportCodes diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts index 4dad1f58..fbba96be 100644 --- a/frontend/src/api/admin/riskControl.ts +++ b/frontend/src/api/admin/riskControl.ts @@ -2,6 +2,12 @@ import { apiClient } from '../client' export type ModerationMode = 'off' | 'observe' | 'pre_block' export type KeywordBlockingMode = 'keyword_only' | 'keyword_and_api' | 'api_only' +export type ContentModerationModelFilterType = 'all' | 'include' | 'exclude' + +export interface ContentModerationModelFilter { + type: ContentModerationModelFilterType + models: string[] +} export interface ContentModerationConfig { enabled: boolean @@ -32,6 +38,7 @@ export interface ContentModerationConfig { pre_hash_check_enabled: boolean blocked_keywords: string[] keyword_blocking_mode: KeywordBlockingMode + model_filter: ContentModerationModelFilter } export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen' @@ -105,6 +112,7 @@ export interface UpdateContentModerationConfig { pre_hash_check_enabled?: boolean blocked_keywords?: string[] keyword_blocking_mode?: KeywordBlockingMode + model_filter?: ContentModerationModelFilter } export interface ContentModerationRuntimeStatus { diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 5632325e..7374d8d3 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -396,6 +396,7 @@ export interface SystemSettings { turnstile_enabled: boolean; turnstile_site_key: string; turnstile_secret_key_configured: boolean; + api_key_acl_trust_forwarded_ip: boolean; // LinuxDo Connect OAuth settings linuxdo_connect_enabled: boolean; @@ -537,10 +538,11 @@ export interface SystemSettings { payment_visible_method_wxpay_enabled?: boolean; openai_advanced_scheduler_enabled?: boolean; - // Balance & quota notification + // 余额、订阅到期与账号限额通知 balance_low_notify_enabled: boolean; balance_low_notify_threshold: number; balance_low_notify_recharge_url: string; + subscription_expiry_notify_enabled: boolean; account_quota_notify_enabled: boolean; account_quota_notify_emails: NotifyEmailEntry[]; @@ -638,6 +640,7 @@ export interface UpdateSettingsRequest { turnstile_enabled?: boolean; turnstile_site_key?: string; turnstile_secret_key?: string; + api_key_acl_trust_forwarded_ip?: boolean; linuxdo_connect_enabled?: boolean; linuxdo_connect_client_id?: string; linuxdo_connect_client_secret?: string; @@ -756,10 +759,11 @@ export interface UpdateSettingsRequest { payment_visible_method_alipay_enabled?: boolean; payment_visible_method_wxpay_enabled?: boolean; openai_advanced_scheduler_enabled?: boolean; - // Balance & quota notification + // 余额、订阅到期与账号限额通知 balance_low_notify_enabled?: boolean; balance_low_notify_threshold?: number; balance_low_notify_recharge_url?: string; + subscription_expiry_notify_enabled?: boolean; account_quota_notify_enabled?: boolean; account_quota_notify_emails?: NotifyEmailEntry[]; @@ -862,6 +866,8 @@ export interface EmailTemplateOption { value: string; label?: string; description?: string; + category?: string; + optional?: boolean; } export type EmailTemplateEventOption = string | EmailTemplateOption; diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 8f81789d..932e5233 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -513,6 +513,12 @@ const handleEvent = (event: { } break + case 'status': + if (event.text) { + addLine(event.text, 'text-cyan-300') + } + break + case 'image': if (event.image_url) { generatedImages.value.push({ diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 24fc90b5..d60b5a04 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2550,7 +2550,10 @@
- +
+ + +
@@ -3268,6 +3271,7 @@ import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' import PlatformIcon from '@/components/common/PlatformIcon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' +import ProxyAdBanner from '@/components/common/ProxyAdBanner.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 9e664764..070887fe 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -1258,7 +1258,10 @@
- +
+ + +
@@ -2224,6 +2227,7 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' +import ProxyAdBanner from '@/components/common/ProxyAdBanner.vue' import GroupSelector from '@/components/common/GroupSelector.vue' import ModelWhitelistSelector from '@/components/account/ModelWhitelistSelector.vue' import QuotaLimitCard from '@/components/account/QuotaLimitCard.vue' diff --git a/frontend/src/components/account/__tests__/AccountTestModal.spec.ts b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts index c82a3840..9670b521 100644 --- a/frontend/src/components/account/__tests__/AccountTestModal.spec.ts +++ b/frontend/src/components/account/__tests__/AccountTestModal.spec.ts @@ -147,4 +147,46 @@ describe('AccountTestModal', () => { mode: 'compact' }) }) + + it('renders Chat Completions path status from test SSE', async () => { + const encoder = new TextEncoder() + const chunks = [ + encoder.encode('data: {"type":"status","text":"已通过 /v1/chat/completions 验证"}\n\n'), + encoder.encode('data: {"type":"test_complete","success":true}\n\n') + ] + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + body: { + getReader: () => ({ + read: vi.fn().mockImplementation(() => Promise.resolve( + chunks.length > 0 + ? { done: false, value: chunks.shift() } + : { done: true, value: undefined } + )) + }) + } + } as any) + + const wrapper = mount(AccountTestModal, { + props: { + show: true, + account: buildAccount() + }, + global: { + stubs: { + BaseDialog: BaseDialogStub, + Select: SelectStub, + TextArea: TextAreaStub, + Icon: true + } + } + }) + + await flushPromises() + ;(wrapper.vm as any).selectedModelId = 'gpt-5.4' + await (wrapper.vm as any).startTest() + await flushPromises() + + expect(wrapper.text()).toContain('已通过 /v1/chat/completions 验证') + }) }) diff --git a/frontend/src/components/charts/TokenUsageTrend.vue b/frontend/src/components/charts/TokenUsageTrend.vue index 4cd126b9..87198995 100644 --- a/frontend/src/components/charts/TokenUsageTrend.vue +++ b/frontend/src/components/charts/TokenUsageTrend.vue @@ -109,8 +109,8 @@ const chartData = computed(() => { { label: 'Cache Hit Rate', data: props.trendData.map((d) => { - const total = d.cache_read_tokens + d.cache_creation_tokens - return total > 0 ? (d.cache_read_tokens / total) * 100 : 0 + const totalPromptTokens = d.input_tokens + d.cache_read_tokens + d.cache_creation_tokens + return totalPromptTokens > 0 ? (d.cache_read_tokens / totalPromptTokens) * 100 : 0 }), borderColor: chartColors.value.cacheHitRate, backgroundColor: `${chartColors.value.cacheHitRate}20`, diff --git a/frontend/src/components/charts/__tests__/TokenUsageTrend.spec.ts b/frontend/src/components/charts/__tests__/TokenUsageTrend.spec.ts new file mode 100644 index 00000000..5fdac9c9 --- /dev/null +++ b/frontend/src/components/charts/__tests__/TokenUsageTrend.spec.ts @@ -0,0 +1,120 @@ +import { describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' + +import TokenUsageTrend from '../TokenUsageTrend.vue' + +const messages: Record = { + 'admin.dashboard.tokenUsageTrend': 'Token Usage Trend', + 'admin.dashboard.noDataAvailable': 'No data available', +} + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => messages[key] ?? key, + }), + } +}) + +vi.mock('vue-chartjs', () => ({ + Line: { + props: ['data', 'options'], + template: '
{{ JSON.stringify(data) }}
', + }, +})) + +describe('TokenUsageTrend', () => { + it('calculates cache hit rate against all prompt tokens', () => { + const wrapper = mount(TokenUsageTrend, { + props: { + trendData: [ + { + date: '2026-05-08', + requests: 1, + input_tokens: 500, + output_tokens: 100, + cache_creation_tokens: 0, + cache_read_tokens: 1500, + cost: 0.01, + actual_cost: 0.005, + }, + ], + }, + global: { + stubs: { + LoadingSpinner: true, + }, + }, + }) + + const chartData = JSON.parse(wrapper.find('.chart-data').text()) + const hitRateDataset = chartData.datasets.find( + (ds: any) => ds.label === 'Cache Hit Rate' + ) + // Hit rate = 1500 / (500 + 1500 + 0) * 100 = 75% + expect(hitRateDataset.data[0]).toBe(75) + }) + + it('returns 0 hit rate when all prompt tokens are zero', () => { + const wrapper = mount(TokenUsageTrend, { + props: { + trendData: [ + { + date: '2026-05-08', + requests: 0, + input_tokens: 0, + output_tokens: 0, + cache_creation_tokens: 0, + cache_read_tokens: 0, + cost: 0, + actual_cost: 0, + }, + ], + }, + global: { + stubs: { + LoadingSpinner: true, + }, + }, + }) + + const chartData = JSON.parse(wrapper.find('.chart-data').text()) + const hitRateDataset = chartData.datasets.find( + (ds: any) => ds.label === 'Cache Hit Rate' + ) + expect(hitRateDataset.data[0]).toBe(0) + }) + + it('includes cache_creation_tokens in denominator for Anthropic models', () => { + const wrapper = mount(TokenUsageTrend, { + props: { + trendData: [ + { + date: '2026-05-08', + requests: 1, + input_tokens: 200, + output_tokens: 50, + cache_creation_tokens: 300, + cache_read_tokens: 500, + cost: 0.02, + actual_cost: 0.01, + }, + ], + }, + global: { + stubs: { + LoadingSpinner: true, + }, + }, + }) + + const chartData = JSON.parse(wrapper.find('.chart-data').text()) + const hitRateDataset = chartData.datasets.find( + (ds: any) => ds.label === 'Cache Hit Rate' + ) + // Hit rate = 500 / (200 + 500 + 300) * 100 = 50% + expect(hitRateDataset.data[0]).toBe(50) + }) +}) \ No newline at end of file diff --git a/frontend/src/components/common/ProxyAdBanner.vue b/frontend/src/components/common/ProxyAdBanner.vue new file mode 100644 index 00000000..52e107fb --- /dev/null +++ b/frontend/src/components/common/ProxyAdBanner.vue @@ -0,0 +1,18 @@ + + + diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 661c9b15..7a8bb607 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -424,6 +424,7 @@ export default { emailSuffixNotAllowed: 'This email domain is not allowed for registration.', emailSuffixNotAllowedWithAllowed: 'This email domain is not allowed. Allowed domains: {suffixes}', + emailSuffixAllowedMore: 'and {count} more', loginSuccess: 'Login successful! Welcome back.', accountCreatedSuccess: 'Account created successfully! Welcome to {siteName}.', reloginRequired: 'Session expired. Please log in again.', @@ -2360,6 +2361,8 @@ export default { webSearchEmulationGlobalDisabled: 'Please enable the global switch first in Settings → Gateway → Web Search Emulation', codexImageGenerationBridge: 'Codex Image Generation Bridge', codexImageGenerationBridgeHint: 'When enabled, Codex /responses text requests in OpenAI groups may be automatically given the image_generation tool. Keep off unless the routed accounts support image generation.', + bedrockCCCompat: 'Bedrock CC Compatibility', + bedrockCCCompatHint: '⚠️ When enabled, requests to Bedrock accounts in this channel will be transformed for Claude Code compatibility (thinking type conversion, tool_use ID sanitization).', basicSettings: 'Basic Settings', addPlatform: 'Add Platform', noPlatforms: 'Click "Add Platform" to start configuring the channel', @@ -2532,6 +2535,20 @@ export default { selectedGroups: 'Selected Groups', searchGroups: 'Search group name or platform', noGroups: 'No groups available', + modelFilter: 'Model scope', + modelFilterHint: 'Moderate by the client-requested model name; channel model mappings do not change this match.', + modelFilterAll: 'All models', + modelFilterAllDesc: 'All model requests go through content moderation.', + modelFilterInclude: 'Only selected', + modelFilterIncludeDesc: 'Only listed models go through content moderation.', + modelFilterExclude: 'Exclude selected', + modelFilterExcludeDesc: 'Listed models skip content moderation; other models are moderated.', + modelFilterModels: 'Model list', + modelFilterModelCount: '{count} models configured', + modelFilterModelsRequired: 'This model scope requires at least 1 model', + modelFilterAllSummary: 'Applies to all models', + modelFilterIncludeSummary: 'Applies to {count} models', + modelFilterExcludeSummary: 'Excludes {count} models', emptyLogs: 'No audit records', workerStatus: 'Worker Runtime', workerStatusHint: 'Queue and worker pool status for asynchronous observation tasks.', @@ -4012,6 +4029,9 @@ export default { createProxy: 'Create Proxy', editProxy: 'Edit Proxy', deleteProxy: 'Delete Proxy', + ad: { + inline: 'Need proxy IP?' + }, dataImport: 'Import', dataExportSelected: 'Export Selected', dataImportTitle: 'Import Proxies', @@ -5307,9 +5327,9 @@ export default { emailVerificationHint: 'Require email verification for new registrations', emailSuffixWhitelist: 'Email Domain Whitelist', emailSuffixWhitelistHint: - "Only email addresses from the specified domains can register (for example, {'@'}qq.com, {'@'}gmail.com)", - emailSuffixWhitelistPlaceholder: 'example.com', - emailSuffixWhitelistInputHint: 'Leave empty for no restriction', + "Only email addresses from the specified domains can register (for example, {'@'}qq.com, {'@'}gmail.com, *.edu.cn)", + emailSuffixWhitelistPlaceholder: "{'@'}example.com, *.edu.cn", + emailSuffixWhitelistInputHint: 'Leave empty for no restriction. Use *.edu.cn to match edu.cn and its subdomains.', promoCode: 'Promo Code', promoCodeHint: 'Allow users to use promo codes during registration', invitationCode: 'Invitation Code Registration', @@ -5334,7 +5354,15 @@ export default { siteKeyHint: 'Get this from your Cloudflare Dashboard', cloudflareDashboard: 'Cloudflare Dashboard', secretKeyHint: 'Server-side verification key (keep this secret)', - secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' }, + secretKeyConfiguredHint: 'Secret key configured. Leave empty to keep the current value.' + }, + apiKeyAcl: { + title: 'API Key IP Access Control', + description: 'Choose which client IP is used by API Key allowlists and denylists', + trustForwardedIp: 'Trust forwarded client IP', + trustForwardedIpHint: + 'Disabled by default. Enable only when the origin is reachable only through Cloudflare or Nginx reverse proxy. When enabled, API Key IP allowlists and denylists use CF-Connecting-IP, X-Real-IP, or X-Forwarded-For, matching the request IP shown in usage records.' + }, linuxdo: { title: 'LinuxDo Connect Login', description: 'Configure LinuxDo Connect OAuth for Sub2API end-user login', @@ -5812,6 +5840,12 @@ export default { addEmail: 'Add Email', emailPlaceholder: 'Enter email address', }, + subscriptionExpiryNotify: { + title: 'Subscription Expiry Reminder', + description: 'Control whether users receive subscription expiry reminder emails.', + enabled: 'Enable Subscription Expiry Reminder', + enabledHint: 'When enabled, the system sends reminders 7, 3, and 1 day before expiry.' + }, smtp: { title: 'SMTP Settings', description: 'Configure email sending for verification codes', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index f50e218b..b23caa8a 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -423,6 +423,7 @@ export default { registrationFailed: '注册失败,请重试。', emailSuffixNotAllowed: '该邮箱域名不在允许注册范围内。', emailSuffixNotAllowedWithAllowed: '该邮箱域名不被允许。可用域名:{suffixes}', + emailSuffixAllowedMore: '等 {count} 项', loginSuccess: '登录成功!欢迎回来。', accountCreatedSuccess: '账户创建成功!欢迎使用 {siteName}。', reloginRequired: '会话已过期,请重新登录。', @@ -2437,6 +2438,8 @@ export default { webSearchEmulationGlobalDisabled: '请先在系统设置 → 网关 → Web Search 模拟中启用全局开关', codexImageGenerationBridge: 'Codex 图片生成桥接', codexImageGenerationBridgeHint: '开启后,OpenAI 分组的 Codex /responses 文本请求可能会被自动注入 image_generation 工具。仅在路由账号支持图片生成时开启。', + bedrockCCCompat: 'Bedrock CC 兼容', + bedrockCCCompatHint: '⚠️ 开启后,该渠道下 Bedrock 账号的请求将进行 Claude Code 兼容处理(thinking 类型转换、tool_use ID 清理)', basicSettings: '基础设置', addPlatform: '添加平台', noPlatforms: '点击"添加平台"开始配置渠道', @@ -2609,6 +2612,20 @@ export default { selectedGroups: '指定分组', searchGroups: '搜索分组名称或平台', noGroups: '暂无可用分组', + modelFilter: '模型范围', + modelFilterHint: '按客户端请求的模型名决定是否执行内容审计,模型映射后仍以请求模型判断。', + modelFilterAll: '所有模型', + modelFilterAllDesc: '所有模型请求都会进入内容审计。', + modelFilterInclude: '仅指定模型', + modelFilterIncludeDesc: '只有列表中的模型会执行内容审计。', + modelFilterExclude: '排除指定模型', + modelFilterExcludeDesc: '列表中的模型跳过内容审计,其余模型执行审计。', + modelFilterModels: '模型列表', + modelFilterModelCount: '已配置 {count} 个模型', + modelFilterModelsRequired: '当前模型范围至少需要配置 1 个模型', + modelFilterAllSummary: '全部模型生效', + modelFilterIncludeSummary: '仅 {count} 个模型生效', + modelFilterExcludeSummary: '排除 {count} 个模型', emptyLogs: '暂无审核记录', workerStatus: 'Worker 运行状态', workerStatusHint: '异步观察任务的队列和 worker 池状态。', @@ -4107,6 +4124,9 @@ export default { createProxy: '添加代理', editProxy: '编辑代理', deleteProxy: '删除代理', + ad: { + inline: '正在寻找合适的代理 IP?' + }, deleteConfirmMessage: "确定要删除代理 '{name}' 吗?", testProxy: '测试代理', dataImport: '导入', @@ -5470,9 +5490,9 @@ export default { emailVerificationHint: '新用户注册时需要验证邮箱', emailSuffixWhitelist: '邮箱域名白名单', emailSuffixWhitelistHint: - "仅允许使用指定域名的邮箱注册账号(例如 {'@'}qq.com, {'@'}gmail.com)", - emailSuffixWhitelistPlaceholder: 'example.com', - emailSuffixWhitelistInputHint: '留空则不限制', + "仅允许使用指定域名的邮箱注册账号(例如 {'@'}qq.com, {'@'}gmail.com, *.edu.cn)", + emailSuffixWhitelistPlaceholder: "{'@'}example.com, *.edu.cn", + emailSuffixWhitelistInputHint: '留空则不限制。使用 *.edu.cn 可匹配 edu.cn 及其子域名。', promoCode: '优惠码', promoCodeHint: '允许用户在注册时使用优惠码', invitationCode: '邀请码注册', @@ -5499,6 +5519,13 @@ export default { secretKeyHint: '服务端验证密钥(请保密)', secretKeyConfiguredHint: '密钥已配置,留空以保留当前值。' }, + apiKeyAcl: { + title: 'API Key IP 访问控制', + description: '控制 API Key 白名单和黑名单使用哪个客户端 IP 判断', + trustForwardedIp: '信任反代传递的客户端 IP', + trustForwardedIpHint: + '默认关闭。仅在源站只允许 Cloudflare 或 Nginx 反代访问时开启;开启后 API Key IP 白/黑名单会使用 CF-Connecting-IP、X-Real-IP 或 X-Forwarded-For,与使用记录中的请求 IP 保持一致。' + }, linuxdo: { title: 'LinuxDo Connect 登录', description: '配置 LinuxDo Connect OAuth,用于 Sub2API 用户登录', @@ -5972,6 +5999,12 @@ export default { addEmail: '添加邮箱', emailPlaceholder: '输入邮箱地址', }, + subscriptionExpiryNotify: { + title: '订阅到期提醒', + description: '控制是否向用户发送订阅即将到期的邮件提醒。', + enabled: '启用订阅到期提醒', + enabledHint: '开启后,系统会在订阅到期前 7 天、3 天、1 天各发送一次提醒。' + }, smtp: { title: 'SMTP 设置', description: '配置用于发送验证码的邮件服务', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index fbdee743..68162e53 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -1477,12 +1477,13 @@ export interface RedeemCode { code: string type: RedeemCodeType value: number - status: 'active' | 'used' | 'expired' | 'unused' + status: 'active' | 'used' | 'expired' | 'unused' | 'disabled' used_by: number | null used_at: string | null created_at: string expires_at?: string | null updated_at?: string + notes?: string group_id?: number | null // 订阅类型专用 validity_days?: number // 订阅类型专用 user?: User @@ -1499,6 +1500,18 @@ export interface GenerateRedeemCodesRequest { expires_in_days?: number } +export interface BatchUpdateRedeemCodeFields { + status?: 'unused' | 'disabled' + expires_at?: string | null + notes?: string + group_id?: number | null +} + +export interface BatchUpdateRedeemCodesRequest { + ids: number[] + fields: BatchUpdateRedeemCodeFields +} + export interface RedeemCodeRequest { code: string } diff --git a/frontend/src/utils/__tests__/registrationEmailPolicy.spec.ts b/frontend/src/utils/__tests__/registrationEmailPolicy.spec.ts index 021f0fc4..492a9a3d 100644 --- a/frontend/src/utils/__tests__/registrationEmailPolicy.spec.ts +++ b/frontend/src/utils/__tests__/registrationEmailPolicy.spec.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from 'vitest' import { + formatRegistrationEmailSuffixWhitelistForMessage, isRegistrationEmailSuffixAllowed, isRegistrationEmailSuffixDomainValid, normalizeRegistrationEmailSuffixDomain, @@ -11,6 +12,7 @@ import { describe('registrationEmailPolicy utils', () => { it('normalizeRegistrationEmailSuffixDomain lowercases, strips @, and ignores invalid chars', () => { expect(normalizeRegistrationEmailSuffixDomain(' @Exa!mple.COM ')).toBe('example.com') + expect(normalizeRegistrationEmailSuffixDomain(' *.EDU!.CN ')).toBe('*.edu.cn') }) it('normalizeRegistrationEmailSuffixDomains deduplicates normalized domains', () => { @@ -22,14 +24,20 @@ describe('registrationEmailPolicy utils', () => { '-invalid.com', 'foo..bar.com', ' @foo.bar ', - '@foo.bar' + '@foo.bar', + '*.EDU.CN', + '*.edu.cn' ]) - ).toEqual(['example.com', 'foo.bar']) + ).toEqual(['example.com', 'foo.bar', '*.edu.cn']) }) it('parseRegistrationEmailSuffixWhitelistInput supports separators and deduplicates', () => { - const input = '\n @example.com,example.com,@foo.bar\t@FOO.bar ' - expect(parseRegistrationEmailSuffixWhitelistInput(input)).toEqual(['example.com', 'foo.bar']) + const input = '\n @example.com,example.com,@foo.bar\t@FOO.bar *.EDU.CN ' + expect(parseRegistrationEmailSuffixWhitelistInput(input)).toEqual([ + 'example.com', + 'foo.bar', + '*.edu.cn' + ]) }) it('parseRegistrationEmailSuffixWhitelistInput drops tokens containing invalid chars', () => { @@ -38,7 +46,7 @@ describe('registrationEmailPolicy utils', () => { }) it('parseRegistrationEmailSuffixWhitelistInput drops structurally invalid domains', () => { - const input = '@-bad.com, @foo..bar.com, @foo.bar, @xn--ok.com' + const input = '@-bad.com, @foo..bar.com, @foo.bar, @xn--ok.com, *., *, *.@, *.foo' expect(parseRegistrationEmailSuffixWhitelistInput(input)).toEqual(['foo.bar', 'xn--ok.com']) }) @@ -53,17 +61,22 @@ describe('registrationEmailPolicy utils', () => { 'foo.bar', '', '-invalid.com', - ' @foo.bar ' + ' @foo.bar ', + '*.EDU.CN' ]) - ).toEqual(['@example.com', '@foo.bar']) + ).toEqual(['@example.com', '@foo.bar', '*.edu.cn']) }) it('isRegistrationEmailSuffixDomainValid matches backend-compatible domain rules', () => { expect(isRegistrationEmailSuffixDomainValid('example.com')).toBe(true) expect(isRegistrationEmailSuffixDomainValid('foo-bar.example.com')).toBe(true) + expect(isRegistrationEmailSuffixDomainValid('*.edu.cn')).toBe(true) expect(isRegistrationEmailSuffixDomainValid('-bad.com')).toBe(false) expect(isRegistrationEmailSuffixDomainValid('foo..bar.com')).toBe(false) expect(isRegistrationEmailSuffixDomainValid('localhost')).toBe(false) + expect(isRegistrationEmailSuffixDomainValid('*.foo')).toBe(false) + expect(isRegistrationEmailSuffixDomainValid('*')).toBe(false) + expect(isRegistrationEmailSuffixDomainValid('*.@')).toBe(false) }) it('isRegistrationEmailSuffixAllowed allows any email when whitelist is empty', () => { @@ -73,5 +86,36 @@ describe('registrationEmailPolicy utils', () => { it('isRegistrationEmailSuffixAllowed applies exact suffix matching', () => { expect(isRegistrationEmailSuffixAllowed('user@example.com', ['@example.com'])).toBe(true) expect(isRegistrationEmailSuffixAllowed('user@sub.example.com', ['@example.com'])).toBe(false) + expect(isRegistrationEmailSuffixAllowed('user@qq.com', ['@qq.com'])).toBe(true) + expect(isRegistrationEmailSuffixAllowed('user@sub.qq.com', ['@qq.com'])).toBe(false) + }) + + it('isRegistrationEmailSuffixAllowed applies wildcard suffix matching', () => { + expect(isRegistrationEmailSuffixAllowed('student@cs.edu.cn', ['*.edu.cn'])).toBe(true) + expect(isRegistrationEmailSuffixAllowed('student@edu.cn', ['*.edu.cn'])).toBe(true) + expect(isRegistrationEmailSuffixAllowed('student@foo.cn', ['*.edu.cn'])).toBe(false) + }) + + it('isRegistrationEmailSuffixAllowed supports mixed exact and wildcard entries', () => { + const whitelist = ['@a.com', '*.b.cn'] + expect(isRegistrationEmailSuffixAllowed('user@a.com', whitelist)).toBe(true) + expect(isRegistrationEmailSuffixAllowed('user@school.b.cn', whitelist)).toBe(true) + expect(isRegistrationEmailSuffixAllowed('user@b.cn', whitelist)).toBe(true) + expect(isRegistrationEmailSuffixAllowed('user@c.cn', whitelist)).toBe(false) + }) + + it('formatRegistrationEmailSuffixWhitelistForMessage lists up to five entries', () => { + expect( + formatRegistrationEmailSuffixWhitelistForMessage( + ['@a.com', '@b.com', '@c.com', '@d.com', '@e.com'], + { separator: ', ', more: (count) => `and ${count} more` } + ) + ).toBe('@a.com, @b.com, @c.com, @d.com, @e.com') + expect( + formatRegistrationEmailSuffixWhitelistForMessage( + ['@a.com', '@b.com', '@c.com', '@d.com', '@e.com', '*.edu.cn', '@f.com'], + { separator: ', ', more: (count) => `and ${count} more` } + ) + ).toBe('@a.com, @b.com, @c.com, @d.com, @e.com, and 2 more') }) }) diff --git a/frontend/src/utils/registrationEmailPolicy.ts b/frontend/src/utils/registrationEmailPolicy.ts index 74d63fc4..bdb3dbc5 100644 --- a/frontend/src/utils/registrationEmailPolicy.ts +++ b/frontend/src/utils/registrationEmailPolicy.ts @@ -2,19 +2,21 @@ const EMAIL_SUFFIX_TOKEN_SPLIT_RE = /[\s,,]+/ const EMAIL_SUFFIX_INVALID_CHAR_RE = /[^a-z0-9.-]/g const EMAIL_SUFFIX_INVALID_CHAR_CHECK_RE = /[^a-z0-9.-]/ const EMAIL_SUFFIX_PREFIX_RE = /^@+/ +const EMAIL_SUFFIX_WILDCARD_PREFIX = '*.' +const EMAIL_SUFFIX_MESSAGE_VISIBLE_LIMIT = 5 const EMAIL_SUFFIX_DOMAIN_PATTERN = /^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)+$/ // normalizeRegistrationEmailSuffixDomain converts raw input into a canonical domain token. -// It removes leading "@", lowercases input, and strips all invalid characters. +// Exact domains are returned without "@"; wildcard domains keep the "*." prefix. export function normalizeRegistrationEmailSuffixDomain(raw: string): string { let value = String(raw || '').trim().toLowerCase() if (!value) { return '' } + value = value.replace(EMAIL_SUFFIX_PREFIX_RE, '') - value = value.replace(EMAIL_SUFFIX_INVALID_CHAR_RE, '') - return value + return normalizeRegistrationEmailSuffixToken(value, false) } export function normalizeRegistrationEmailSuffixDomains( @@ -60,7 +62,7 @@ export function parseRegistrationEmailSuffixWhitelistInput(input: string): strin export function normalizeRegistrationEmailSuffixWhitelist( items: string[] | null | undefined ): string[] { - return normalizeRegistrationEmailSuffixDomains(items).map((domain) => `@${domain}`) + return normalizeRegistrationEmailSuffixDomains(items).map(toCanonicalRegistrationEmailSuffix) } function extractRegistrationEmailDomain(email: string): string { @@ -91,7 +93,32 @@ export function isRegistrationEmailSuffixAllowed( return false } const emailSuffix = `@${emailDomain}` - return normalizedWhitelist.includes(emailSuffix) + return normalizedWhitelist.some((allowed) => { + if (allowed.startsWith('@')) { + return allowed === emailSuffix + } + if (allowed.startsWith(EMAIL_SUFFIX_WILDCARD_PREFIX)) { + const base = allowed.slice(EMAIL_SUFFIX_WILDCARD_PREFIX.length) + return emailDomain === base || emailDomain.endsWith(`.${base}`) + } + return false + }) +} + +export function formatRegistrationEmailSuffixWhitelistForMessage( + whitelist: string[] | null | undefined, + options: { + separator: string + more: (count: number) => string + } +): string { + const normalizedWhitelist = normalizeRegistrationEmailSuffixWhitelist(whitelist) + const visible = normalizedWhitelist.slice(0, EMAIL_SUFFIX_MESSAGE_VISIBLE_LIMIT) + const hiddenCount = normalizedWhitelist.length - visible.length + if (hiddenCount > 0) { + visible.push(options.more(hiddenCount)) + } + return visible.join(options.separator) } // Pasted domains should be strict: any invalid character drops the whole token. @@ -101,15 +128,38 @@ function normalizeRegistrationEmailSuffixDomainStrict(raw: string): string { return '' } value = value.replace(EMAIL_SUFFIX_PREFIX_RE, '') - if (!value || EMAIL_SUFFIX_INVALID_CHAR_CHECK_RE.test(value)) { - return '' - } - return value + return normalizeRegistrationEmailSuffixToken(value, true) } export function isRegistrationEmailSuffixDomainValid(domain: string): boolean { if (!domain) { return false } - return EMAIL_SUFFIX_DOMAIN_PATTERN.test(domain) + if (domain.startsWith(EMAIL_SUFFIX_WILDCARD_PREFIX)) { + return EMAIL_SUFFIX_DOMAIN_PATTERN.test(domain.slice(EMAIL_SUFFIX_WILDCARD_PREFIX.length)) + } + return !domain.includes('*') && EMAIL_SUFFIX_DOMAIN_PATTERN.test(domain) +} + +function normalizeRegistrationEmailSuffixToken(value: string, strict: boolean): string { + if (value.startsWith(EMAIL_SUFFIX_WILDCARD_PREFIX)) { + const domain = value.slice(EMAIL_SUFFIX_WILDCARD_PREFIX.length) + if (strict && (!domain || EMAIL_SUFFIX_INVALID_CHAR_CHECK_RE.test(domain))) { + return '' + } + return `${EMAIL_SUFFIX_WILDCARD_PREFIX}${domain.replace(EMAIL_SUFFIX_INVALID_CHAR_RE, '')}` + } + + if (value === '*') { + return strict ? '' : value + } + + if (strict && EMAIL_SUFFIX_INVALID_CHAR_CHECK_RE.test(value)) { + return '' + } + return value.replace(/[*]/g, '').replace(EMAIL_SUFFIX_INVALID_CHAR_RE, '') +} + +function toCanonicalRegistrationEmailSuffix(domain: string): string { + return domain.startsWith(EMAIL_SUFFIX_WILDCARD_PREFIX) ? domain : `@${domain}` } diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 93724aa6..4924ce55 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -354,6 +354,21 @@ + +
+
+
+ +

+ {{ t('admin.channels.form.bedrockCCCompatHint') }} +

+
+ +
+
+
@@ -669,6 +684,7 @@ interface PlatformSection { model_pricing: PricingFormEntry[] web_search_emulation: boolean codex_image_generation_bridge: boolean + bedrock_cc_compat: boolean account_stats_pricing_rules: FormPricingRule[] } @@ -765,6 +781,7 @@ function addPlatformSection(platform: GroupPlatform) { model_pricing: [], web_search_emulation: false, codex_image_generation_bridge: false, + bedrock_cc_compat: false, account_stats_pricing_rules: [], }) } @@ -1125,6 +1142,19 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ delete featuresConfig.codex_image_generation_bridge } + const bedrockCCCompat: Record = {} + for (const section of form.platforms) { + if (!section.enabled) continue + if (section.platform === 'anthropic') { + bedrockCCCompat[section.platform] = !!section.bedrock_cc_compat + } + } + if (Object.keys(bedrockCCCompat).length > 0) { + featuresConfig.bedrock_cc_compat = bedrockCCCompat + } else { + delete featuresConfig.bedrock_cc_compat + } + return { group_ids, model_pricing, model_mapping, features_config: featuresConfig } } @@ -1175,6 +1205,8 @@ function apiToForm(channel: Channel): PlatformSection[] { const webSearchEnabled = wsEmulation?.[platform] === true const codexImageGenerationBridge = fc?.codex_image_generation_bridge as Record | undefined const codexImageGenerationBridgeEnabled = codexImageGenerationBridge?.[platform] === true + const bedrockCCCompat = fc?.bedrock_cc_compat as Record | undefined + const bedrockCCCompatEnabled = bedrockCCCompat?.[platform] === true sections.push({ platform, @@ -1185,6 +1217,7 @@ function apiToForm(channel: Channel): PlatformSection[] { model_pricing: pricing, web_search_emulation: webSearchEnabled, codex_image_generation_bridge: codexImageGenerationBridgeEnabled, + bedrock_cc_compat: bedrockCCCompatEnabled, account_stats_pricing_rules: [], }) } diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index 1e4df356..27f38307 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -357,45 +357,50 @@ @close="closeCreateModal" > -
- - + + {{ t('admin.proxies.standardAdd') }} + + +
+
@@ -887,6 +892,7 @@ import ConfirmDialog from '@/components/common/ConfirmDialog.vue' import EmptyState from '@/components/common/EmptyState.vue' import ImportDataModal from '@/components/admin/proxy/ImportDataModal.vue' import Select from '@/components/common/Select.vue' +import ProxyAdBanner from '@/components/common/ProxyAdBanner.vue' import Icon from '@/components/icons/Icon.vue' import PlatformTypeBadge from '@/components/common/PlatformTypeBadge.vue' import { useClipboard } from '@/composables/useClipboard' diff --git a/frontend/src/views/admin/RedeemView.vue b/frontend/src/views/admin/RedeemView.vue index b8e0e936..faae7439 100644 --- a/frontend/src/views/admin/RedeemView.vue +++ b/frontend/src/views/admin/RedeemView.vue @@ -39,6 +39,15 @@ + @@ -56,6 +65,28 @@ default-sort-order="desc" @sort="handleSort" > + + + +