chore: merge upstream Wei-Shaw/sub2api latest (v0.1.130+)

Upstream features: bedrock CC compat, email whitelist wildcard,
content moderation per-model toggle, redeem code batch update,
OIDC verified-email fast path, subscription expiry email,
cache hit rate fix, audit dedup, js-cookie security fix,
x/net vulnerability fix, OpenAI account cooldown optimization,
reverse proxy client IP fix, API key ACL trusted forwarded IP.

Local additions preserved: rpmTokenBucketService, quotaFactor
scoring, P2C scheduler selection.
This commit is contained in:
win 2026-05-24 15:54:54 +08:00
commit e938be5f3f
124 changed files with 6776 additions and 890 deletions

View File

@ -114,6 +114,12 @@ Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to recei
</td>
</tr>
<tr>
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
<td>Thanks to RunAPI for sponsoring this project! <a href="https://runapi.co/register?aff=fu2E">RunAPI</a> is an efficient and stable API platform and OpenRouter alternative. With one API Key, you can access 150+ popular models including OpenAI, Claude, Gemini, DeepSeek, and Grok, with pricing as low as 10% of the original rate. It is highly stable and seamlessly compatible with tools such as Claude Code and OpenClaw.
</td>
</tr>
</table>
## Ecosystem

View File

@ -112,6 +112,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
<td>感谢 PPToken.org 赞助本项目! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:11 元=1 美元额度GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术10 分钟内有回复 。赞助商福利:前 200 名用户通过 <a href="https://api.pptoken.org/register?promo=SUB2API">[专属注册链接]</a> 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。
</td>
</tr>
<tr>
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
<td>感谢 RunAPI 赞助本项目! <a href="https://runapi.co/register?aff=fu2E">RunAPI</a> 是高效稳定的API OpenRouter平替平台一个 API Key 即可访问 OpenAI、Claude、Gemini、DeepSeek、Grok 等 150+ 主流模型,低至 1 折,极其稳定,可以无缝兼容 Claude Code、OpenClaw 等工具。
</td>
</tr>
</table>
## 生态项目

View File

@ -113,6 +113,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
</td>
</tr>
<tr>
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
<td>RunAPI のご支援に感謝します!<a href="https://runapi.co/register?aff=fu2E">RunAPI</a> は効率的で安定した API プラットフォームで、OpenRouter の代替として利用できます。1つの API キーで OpenAI、Claude、Gemini、DeepSeek、Grok など 150以上の主要モデルにアクセスでき、価格は最低 10% から。非常に安定しており、Claude Code や OpenClaw などのツールとシームレスに互換します。
</td>
</tr>
</table>
## エコシステム

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

View File

@ -1 +1 @@
0.1.129
0.1.130

View File

@ -24,8 +24,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/web"
"github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
//go:embed VERSION
@ -116,11 +114,16 @@ func runSetupServer() {
log.Printf("Setup wizard available at http://%s", addr)
log.Println("Complete the setup wizard to configure Sub2API")
protocols := new(http.Protocols)
protocols.SetHTTP1(true)
protocols.SetUnencryptedHTTP2(true)
server := &http.Server{
Addr: addr,
Handler: h2c.NewHandler(r, &http2.Server{}),
Handler: r,
ReadHeaderTimeout: 30 * time.Second,
IdleTimeout: 120 * time.Second,
Protocols: protocols,
}
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {

View File

@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
privacyClientFactory := providePrivacyClientFactory()
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
httpUpstream := repository.NewHTTPUpstream(configConfig)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
openAIOAuthClient := repository.NewOpenAIOAuthClient()
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
rpmCache := repository.NewRPMCache(redisClient)
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
driveClient := repository.NewGeminiDriveClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
gatewayCache := repository.NewGatewayCache(redisClient)
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
@ -176,25 +191,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
return nil, err
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
digestSessionStore := service.NewDigestSessionStore()
channelRepository := repository.NewChannelRepository(db)
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
rpmTokenBucketService := service.NewRPMTokenBucketService()
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
@ -271,9 +271,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService)
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)

View File

@ -15,6 +15,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0
github.com/coder/websocket v1.8.14
github.com/dgraph-io/ristretto v0.2.0
github.com/docker/docker v28.5.2+incompatible
github.com/gin-gonic/gin v1.9.1
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/google/uuid v1.6.0
@ -41,11 +42,11 @@ require (
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
github.com/zeromicro/go-zero v1.9.4
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.50.0
golang.org/x/crypto v0.51.0
golang.org/x/image v0.39.0
golang.org/x/net v0.53.0
golang.org/x/net v0.55.0
golang.org/x/sync v0.20.0
golang.org/x/term v0.42.0
golang.org/x/term v0.43.0
google.golang.org/protobuf v1.36.10
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
@ -87,7 +88,6 @@ require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/docker v28.5.2+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
@ -107,7 +107,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@ -175,10 +174,10 @@ require (
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.43.0 // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/tools v0.44.0 // indirect
google.golang.org/grpc v1.75.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect

View File

@ -108,8 +108,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
@ -166,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
@ -222,8 +218,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -257,8 +251,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -288,8 +280,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@ -322,8 +312,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@ -415,16 +403,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -436,16 +424,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=

View File

@ -9,6 +9,7 @@ import (
"net/url"
"os"
"strings"
"sync/atomic"
"time"
"github.com/spf13/viper"
@ -581,11 +582,35 @@ type CORSConfig struct {
}
type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
TrustForwardedIPForAPIKeyACL bool `mapstructure:"trust_forwarded_ip_for_api_key_acl"`
trustForwardedIPForAPIKeyACLLive *atomic.Bool `mapstructure:"-"`
}
func (c *Config) TrustForwardedIPForAPIKeyACL() bool {
if c == nil {
return false
}
live := c.Security.trustForwardedIPForAPIKeyACLLive
if live == nil {
return c.Security.TrustForwardedIPForAPIKeyACL
}
return live.Load()
}
func (c *Config) SetTrustForwardedIPForAPIKeyACL(enabled bool) {
if c == nil {
return
}
c.Security.TrustForwardedIPForAPIKeyACL = enabled
if c.Security.trustForwardedIPForAPIKeyACLLive == nil {
c.Security.trustForwardedIPForAPIKeyACLLive = &atomic.Bool{}
}
c.Security.trustForwardedIPForAPIKeyACLLive.Store(enabled)
}
type URLAllowlistConfig struct {
@ -1083,7 +1108,8 @@ type GatewaySchedulingConfig struct {
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
// 负载计算
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"`
// 快照桶读取时的 MGET 分块大小
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
// 快照重建时的缓存写入分块大小
@ -1466,6 +1492,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
cfg.SetTrustForwardedIPForAPIKeyACL(cfg.Security.TrustForwardedIPForAPIKeyACL)
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
@ -1610,6 +1637,7 @@ func setDefaults() {
viper.SetDefault("security.csp.enabled", true)
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
viper.SetDefault("security.trust_forwarded_ip_for_api_key_acl", false)
// Security - disable direct fallback on proxy error
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
@ -1905,6 +1933,7 @@ func setDefaults() {
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200)
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
@ -2766,6 +2795,9 @@ func (c *Config) Validate() error {
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
}
if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 {
return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative")
}
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
}

View File

@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
t.Fatalf("LoadBatchEnabled = false, want true")
}
if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 {
t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS)
}
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
}
@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) {
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
wantErr: "gateway.scheduling.sticky_session_max_waiting",
},
{
name: "gateway scheduling load batch cache ttl",
mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 },
wantErr: "gateway.scheduling.load_batch_cache_ttl_ms",
},
{
name: "gateway scheduling outbox poll",
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },

View File

@ -20,34 +20,35 @@ func NewContentModerationHandler(svc *service.ContentModerationService) *Content
}
type contentModerationConfigRequest struct {
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
APIKeysMode string `json:"api_keys_mode"`
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
BlockedKeywords *[]string `json:"blocked_keywords"`
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
APIKeysMode string `json:"api_keys_mode"`
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
BlockedKeywords *[]string `json:"blocked_keywords"`
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
ModelFilter *service.ContentModerationModelFilter `json:"model_filter"`
}
type contentModerationAPIKeyTestRequest struct {
@ -107,6 +108,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
PreHashCheckEnabled: req.PreHashCheckEnabled,
BlockedKeywords: req.BlockedKeywords,
KeywordBlockingMode: req.KeywordBlockingMode,
ModelFilter: req.ModelFilter,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@ -307,6 +307,51 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
})
}
// BatchUpdate handles batch updating redeem codes
// POST /api/v1/admin/redeem-codes/batch-update
func (h *RedeemHandler) BatchUpdate(c *gin.Context) {
if h.redeemService == nil {
response.InternalError(c, "redeem service not configured")
return
}
var req dto.BatchUpdateRedeemCodesRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
result, err := h.redeemService.BatchUpdate(c.Request.Context(), &service.RedeemCodeBatchUpdateInput{
IDs: req.IDs,
Fields: redeemBatchUpdateFieldsFromDTO(req.Fields),
})
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, gin.H{
"updated": result.Updated,
"message": "Redeem codes updated successfully",
})
}
func redeemBatchUpdateFieldsFromDTO(in dto.BatchUpdateRedeemCodeFields) service.RedeemCodeBatchUpdateFields {
out := service.RedeemCodeBatchUpdateFields{
Status: in.Status,
Notes: in.Notes,
Type: in.Type,
Value: in.Value,
}
if in.ExpiresAt.Set {
out.ExpiresAt = service.NullableTimeUpdate{Set: true, Value: in.ExpiresAt.Value}
}
if in.GroupID.Set {
out.GroupID = service.NullableInt64Update{Set: true, Value: in.GroupID.Value}
}
return out
}
// Expire handles expiring a redeem code
// POST /api/v1/admin/redeem-codes/:id/expire
func (h *RedeemHandler) Expire(c *gin.Context) {

View File

@ -142,6 +142,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
TurnstileEnabled: settings.TurnstileEnabled,
TurnstileSiteKey: settings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
APIKeyACLTrustForwardedIP: settings.APIKeyACLTrustForwardedIP,
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
@ -264,6 +265,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
SubscriptionExpiryNotifyEnabled: settings.SubscriptionExpiryNotifyEnabled,
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
PaymentEnabled: paymentCfg.Enabled,
@ -399,6 +401,9 @@ type UpdateSettingsRequest struct {
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKey string `json:"turnstile_secret_key"`
// API Key IP 访问控制设置
APIKeyACLTrustForwardedIP *bool `json:"api_key_acl_trust_forwarded_ip"`
// LinuxDo Connect OAuth 登录
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
@ -582,12 +587,13 @@ type UpdateSettingsRequest struct {
// OpenAI account scheduling
OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
// Balance low notification
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// 余额不足提醒
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
SubscriptionExpiryNotifyEnabled *bool `json:"subscription_expiry_notify_enabled"`
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
// Payment configuration (integrated into settings, full replace)
PaymentEnabled *bool `json:"payment_enabled"`
@ -1432,28 +1438,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
APIKeyACLTrustForwardedIP: func() bool {
if req.APIKeyACLTrustForwardedIP != nil {
return *req.APIKeyACLTrustForwardedIP
}
return previousSettings.APIKeyACLTrustForwardedIP
}(),
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
@ -1669,6 +1681,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.BalanceLowNotifyRechargeURL
}(),
SubscriptionExpiryNotifyEnabled: func() bool {
if req.SubscriptionExpiryNotifyEnabled != nil {
return *req.SubscriptionExpiryNotifyEnabled
}
return previousSettings.SubscriptionExpiryNotifyEnabled
}(),
AccountQuotaNotifyEnabled: func() bool {
if req.AccountQuotaNotifyEnabled != nil {
return *req.AccountQuotaNotifyEnabled
@ -1869,6 +1887,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
TurnstileEnabled: updatedSettings.TurnstileEnabled,
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
APIKeyACLTrustForwardedIP: updatedSettings.APIKeyACLTrustForwardedIP,
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
@ -1989,6 +2008,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
SubscriptionExpiryNotifyEnabled: updatedSettings.SubscriptionExpiryNotifyEnabled,
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
PaymentEnabled: updatedPaymentCfg.Enabled,
@ -2145,6 +2165,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if req.TurnstileSecretKey != "" {
changed = append(changed, "turnstile_secret_key")
}
if before.APIKeyACLTrustForwardedIP != after.APIKeyACLTrustForwardedIP {
changed = append(changed, "api_key_acl_trust_forwarded_ip")
}
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
changed = append(changed, "linuxdo_connect_enabled")
}
@ -2454,7 +2477,7 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
changed = append(changed, "openai_advanced_scheduler_enabled")
}
// Balance & quota notification
// 余额、订阅到期与账号限额通知
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
changed = append(changed, "balance_low_notify_enabled")
}
@ -2464,6 +2487,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
changed = append(changed, "balance_low_notify_recharge_url")
}
if before.SubscriptionExpiryNotifyEnabled != after.SubscriptionExpiryNotifyEnabled {
changed = append(changed, "subscription_expiry_notify_enabled")
}
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
changed = append(changed, "account_quota_notify_enabled")
}
@ -3472,6 +3498,8 @@ func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo)
Value: event.Event,
Label: event.Label,
Description: event.Description,
Category: event.Category,
Optional: event.Optional,
})
}
return items

View File

@ -2492,6 +2492,10 @@ func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *servi
return err
}
func (r *oauthPendingFlowRedeemCodeRepo) BatchUpdate(context.Context, []int64, service.RedeemCodeBatchUpdateFields) (int64, error) {
panic("unexpected BatchUpdate call")
}
func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}

View File

@ -2,6 +2,7 @@ package handler
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
@ -32,6 +33,13 @@ func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
return nil
}
func requireCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper()
cookie := findCookie(recorder.Result().Cookies(), name)
require.NotNil(t, cookie)
require.Equal(t, -1, cookie.MaxAge)
}
func decodeCookieValueForTest(t *testing.T, value string) string {
t.Helper()
decoded, err := decodeCookieValue(value)
@ -40,6 +48,13 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
}
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
t.Helper()
values := parseOAuthRedirectFragment(t, location)
require.Equal(t, errorCode, values.Get("error"))
require.Equal(t, errorMessage, values.Get("error_message"))
}
func parseOAuthRedirectFragment(t *testing.T, location string) url.Values {
t.Helper()
require.NotEmpty(t, location)
@ -52,6 +67,5 @@ func assertOAuthRedirectError(t *testing.T, location string, errorCode string, e
}
values, err := url.ParseQuery(rawValues)
require.NoError(t, err)
require.Equal(t, errorCode, values.Get("error"))
require.Equal(t, errorMessage, values.Get("error_message"))
return values
}

View File

@ -454,6 +454,24 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
}
}
// 快捷路径:当上游返回已验证邮箱、部署不要求额外确认且本地没有同邮箱账号时,
// 直接信任上游身份完成注册/登录,避免展示 choice 页。
if compatEmailUser == nil &&
strings.TrimSpace(compatEmail) != "" &&
emailVerified != nil && *emailVerified {
if handled := h.tryOIDCVerifiedEmailFastPath(
c,
frontendCallback,
redirectTo,
identityRef,
compatEmail,
username,
upstreamClaims,
); handled {
return
}
}
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
if err := h.createOIDCOAuthChoicePendingSession(
c,
@ -1190,3 +1208,70 @@ func oidcClearCookie(c *gin.Context, name string, secure bool) {
SameSite: http.SameSiteLaxMode,
})
}
// tryOIDCVerifiedEmailFastPath 在 OIDC 上游已返回已验证邮箱时尝试跳过 choice/pending 页。
// 返回 true 表示已经写出重定向响应;返回 false 表示调用方应继续回退到常规 choice 流程。
func (h *AuthHandler) tryOIDCVerifiedEmailFastPath(
c *gin.Context,
frontendCallback string,
redirectTo string,
identity service.PendingAuthIdentityKey,
compatEmail string,
username string,
upstreamClaims map[string]any,
) bool {
if h == nil || h.authService == nil || h.settingSvc == nil {
return false
}
ctx := c.Request.Context()
if h.isForceEmailOnThirdPartySignup(ctx) {
return false
}
if h.settingSvc.IsInvitationCodeEnabled(ctx) {
return false
}
if err := h.ensureBackendModeAllowsNewUserLogin(ctx); err != nil {
log.Printf("[OIDC OAuth] verified-email fast path blocked by backend mode: reason=%s", infraerrors.Reason(err))
clearOAuthPendingSessionCookie(c, isRequestHTTPS(c))
clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c))
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
return true
}
verifiedEmail := strings.TrimSpace(strings.ToLower(compatEmail))
upstreamMetadata := make(map[string]any, len(upstreamClaims)+1)
for k, v := range upstreamClaims {
upstreamMetadata[k] = v
}
if syntheticEmail := pendingSessionStringValue(upstreamClaims, "email"); syntheticEmail != "" && !strings.EqualFold(syntheticEmail, verifiedEmail) {
upstreamMetadata["synthetic_email"] = syntheticEmail
}
upstreamMetadata["email"] = verifiedEmail
input := service.EmailOAuthIdentityInput{
ProviderType: strings.TrimSpace(identity.ProviderType),
ProviderKey: strings.TrimSpace(identity.ProviderKey),
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
Email: verifiedEmail,
EmailVerified: true,
Username: strings.TrimSpace(username),
DisplayName: pendingSessionStringValue(upstreamClaims, "suggested_display_name"),
AvatarURL: pendingSessionStringValue(upstreamClaims, "suggested_avatar_url"),
UpstreamMetadata: upstreamMetadata,
}
tokenPair, _, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(ctx, input, "", "")
if err != nil {
log.Printf("[OIDC OAuth] verified-email fast path skipped: reason=%s", infraerrors.Reason(err))
return false
}
fragment := url.Values{}
fragment.Set("access_token", tokenPair.AccessToken)
fragment.Set("refresh_token", tokenPair.RefreshToken)
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
fragment.Set("token_type", "Bearer")
fragment.Set("redirect", redirectTo)
clearOAuthPendingSessionCookie(c, isRequestHTTPS(c))
clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c))
redirectWithFragment(c, frontendCallback, fragment)
return true
}

View File

@ -10,6 +10,7 @@ import (
"math/big"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
@ -926,6 +927,232 @@ func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUser
require.Nil(t, storedSession.ConsumedAt)
}
func TestTryOIDCVerifiedEmailFastPathCreatesUserAndIdentity(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, false)
t.Cleanup(func() { _ = client.Close() })
ctx := context.Background()
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
identity := service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "fast-path-subject",
}
completed := handler.tryOIDCVerifiedEmailFastPath(
c,
"/auth/oidc/callback",
"/dashboard",
identity,
"fastpath@example.com",
"fastpath_user",
map[string]any{
"suggested_display_name": "Fast Path",
"suggested_avatar_url": "",
},
)
require.True(t, completed)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "/auth/oidc/callback")
require.Contains(t, location, "access_token=")
require.Contains(t, location, "refresh_token=")
require.Contains(t, location, "token_type=Bearer")
user, err := client.User.Query().Where(dbuser.EmailEQ("fastpath@example.com")).Only(ctx)
require.NoError(t, err)
require.Equal(t, "fastpath_user", user.Username)
require.Equal(t, "oidc", user.SignupSource)
identityRecord, err := client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ("https://issuer.example.com"),
authidentity.ProviderSubjectEQ("fast-path-subject"),
authidentity.UserIDEQ(user.ID),
).Only(ctx)
require.NoError(t, err)
require.Equal(t, "fastpath@example.com", identityRecord.Metadata["email"])
require.Equal(t, true, identityRecord.Metadata["email_verified"])
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, pendingCount)
}
func TestOIDCOAuthCallbackVerifiedEmailFastPathIssuesTokenWithoutPendingSession(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-fast-callback-subject",
PreferredUsername: "oidc_fast_callback",
DisplayName: "OIDC Fast Callback",
AvatarURL: "https://cdn.example/oidc-fast.png",
Email: "oidc-fast-callback@example.com",
EmailVerified: true,
})
defer cleanup()
handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, nil)
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-fast-callback", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-fast-callback"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-fast-callback"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-callback-subject"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-fast-callback"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
require.Contains(t, location, "/auth/oidc/callback#")
require.Contains(t, location, "access_token=")
require.Contains(t, location, "refresh_token=")
require.Contains(t, location, "token_type=Bearer")
fragmentValues := parseOAuthRedirectFragment(t, location)
require.Equal(t, "/dashboard", fragmentValues.Get("redirect"))
requireCookieCleared(t, recorder, oauthPendingSessionCookieName)
requireCookieCleared(t, recorder, oauthPendingBrowserCookieName)
ctx := context.Background()
user, err := client.User.Query().Where(dbuser.EmailEQ("oidc-fast-callback@example.com")).Only(ctx)
require.NoError(t, err)
require.Equal(t, "oidc_fast_callback", user.Username)
require.Equal(t, "oidc", user.SignupSource)
identity, err := client.AuthIdentity.Query().Where(
authidentity.ProviderTypeEQ("oidc"),
authidentity.ProviderKeyEQ(cfg.IssuerURL),
authidentity.ProviderSubjectEQ("oidc-fast-callback-subject"),
authidentity.UserIDEQ(user.ID),
).Only(ctx)
require.NoError(t, err)
require.Equal(t, "oidc-fast-callback@example.com", identity.Metadata["email"])
require.Equal(t, true, identity.Metadata["email_verified"])
require.Equal(t, "OIDC Fast Callback", identity.Metadata["suggested_display_name"])
require.NotEqual(t, identity.Metadata["email"], identity.Metadata["synthetic_email"])
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, pendingCount)
}
func TestOIDCOAuthCallbackVerifiedEmailFastPathBackendModeBlocksBeforeUserCreation(t *testing.T) {
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
Subject: "oidc-fast-backend-mode-subject",
PreferredUsername: "oidc_backend_mode",
DisplayName: "OIDC Backend Mode",
Email: "oidc-backend-mode@example.com",
EmailVerified: true,
})
defer cleanup()
handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, map[string]string{
service.SettingKeyBackendModeEnabled: "true",
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-backend-mode", nil)
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-backend-mode"))
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-backend-mode"))
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-backend-mode-subject"))
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-backend-mode"))
c.Request = req
handler.OIDCOAuthCallback(c)
require.Equal(t, http.StatusFound, recorder.Code)
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "login_blocked", "BACKEND_MODE_ADMIN_ONLY")
requireCookieCleared(t, recorder, oauthPendingSessionCookieName)
requireCookieCleared(t, recorder, oauthPendingBrowserCookieName)
ctx := context.Background()
userCount, err := client.User.Query().Where(dbuser.EmailEQ("oidc-backend-mode@example.com")).Count(ctx)
require.NoError(t, err)
require.Zero(t, userCount)
identityCount, err := client.AuthIdentity.Query().
Where(authidentity.ProviderSubjectEQ("oidc-fast-backend-mode-subject")).
Count(ctx)
require.NoError(t, err)
require.Zero(t, identityCount)
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
require.NoError(t, err)
require.Zero(t, pendingCount)
}
func TestTryOIDCVerifiedEmailFastPathSkippedWhenInvitationCodeRequired(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandler(t, true)
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
identity := service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "fast-path-skipped-invitation",
}
completed := handler.tryOIDCVerifiedEmailFastPath(
c,
"/auth/oidc/callback",
"/dashboard",
identity,
"invite-only@example.com",
"invite_only_user",
map[string]any{},
)
require.False(t, completed)
require.NotEqual(t, http.StatusFound, recorder.Code)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("invite-only@example.com")).Count(context.Background())
require.NoError(t, err)
require.Zero(t, userCount)
}
func TestTryOIDCVerifiedEmailFastPathSkippedWhenForceEmailEnabled(t *testing.T) {
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
settingValues: map[string]string{
service.SettingKeyForceEmailOnThirdPartySignup: "true",
},
})
t.Cleanup(func() { _ = client.Close() })
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
identity := service.PendingAuthIdentityKey{
ProviderType: "oidc",
ProviderKey: "https://issuer.example.com",
ProviderSubject: "fast-path-skipped-force-email",
}
completed := handler.tryOIDCVerifiedEmailFastPath(
c,
"/auth/oidc/callback",
"/dashboard",
identity,
"force-email@example.com",
"force_email_user",
map[string]any{},
)
require.False(t, completed)
userCount, err := client.User.Query().Where(dbuser.EmailEQ("force-email@example.com")).Count(context.Background())
require.NoError(t, err)
require.Zero(t, userCount)
}
type oidcProviderFixture struct {
Subject string
PreferredUsername string
@ -957,6 +1184,49 @@ func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg
return handler, client
}
func newOIDCOAuthHandlerAndClientWithSettings(
t *testing.T,
invitationEnabled bool,
oauthCfg config.OIDCConnectConfig,
settingValues map[string]string,
) (*AuthHandler, *dbent.Client) {
t.Helper()
values := map[string]string{
service.SettingKeyOIDCConnectEnabled: "true",
service.SettingKeyOIDCConnectProviderName: strings.TrimSpace(firstNonEmpty(oauthCfg.ProviderName, "OIDC")),
service.SettingKeyOIDCConnectClientID: oauthCfg.ClientID,
service.SettingKeyOIDCConnectClientSecret: oauthCfg.ClientSecret,
service.SettingKeyOIDCConnectIssuerURL: oauthCfg.IssuerURL,
service.SettingKeyOIDCConnectAuthorizeURL: oauthCfg.AuthorizeURL,
service.SettingKeyOIDCConnectTokenURL: oauthCfg.TokenURL,
service.SettingKeyOIDCConnectUserInfoURL: oauthCfg.UserInfoURL,
service.SettingKeyOIDCConnectJWKSURL: oauthCfg.JWKSURL,
service.SettingKeyOIDCConnectScopes: oauthCfg.Scopes,
service.SettingKeyOIDCConnectRedirectURL: oauthCfg.RedirectURL,
service.SettingKeyOIDCConnectFrontendRedirectURL: oauthCfg.FrontendRedirectURL,
service.SettingKeyOIDCConnectTokenAuthMethod: oauthCfg.TokenAuthMethod,
service.SettingKeyOIDCConnectUsePKCE: boolSettingValue(oauthCfg.UsePKCE),
service.SettingKeyOIDCConnectValidateIDToken: boolSettingValue(oauthCfg.ValidateIDToken),
service.SettingKeyOIDCConnectAllowedSigningAlgs: oauthCfg.AllowedSigningAlgs,
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
service.SettingKeyOIDCConnectRequireEmailVerified: boolSettingValue(oauthCfg.RequireEmailVerified),
}
for key, value := range settingValues {
values[key] = value
}
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
invitationEnabled: invitationEnabled,
settingValues: values,
})
if handler.cfg == nil {
handler.cfg = &config.Config{}
}
handler.cfg.OIDC = oauthCfg
return handler, client
}
func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) {
t.Helper()

View File

@ -50,6 +50,7 @@ type SystemSettings struct {
TurnstileEnabled bool `json:"turnstile_enabled"`
TurnstileSiteKey string `json:"turnstile_site_key"`
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
APIKeyACLTrustForwardedIP bool `json:"api_key_acl_trust_forwarded_ip"`
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
@ -222,12 +223,13 @@ type SystemSettings struct {
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
// 余额、订阅到期与账号限额通知
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
SubscriptionExpiryNotifyEnabled bool `json:"subscription_expiry_notify_enabled"`
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
// Channel Monitor feature switch
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
@ -378,11 +380,13 @@ type OpenAIFastPolicySettings struct {
Rules []OpenAIFastPolicyRule `json:"rules"`
}
// EmailTemplateEventOption describes an editable notification email event.
// EmailTemplateEventOption 描述可编辑的通知邮件事件。
type EmailTemplateEventOption struct {
Value string `json:"value"`
Label string `json:"label,omitempty"`
Description string `json:"description,omitempty"`
Category string `json:"category,omitempty"`
Optional bool `json:"optional,omitempty"`
}
// EmailTemplateSummary is shown in the admin email template list.

View File

@ -1,6 +1,8 @@
package dto
import (
"bytes"
"encoding/json"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
@ -359,6 +361,59 @@ type AdminRedeemCode struct {
Notes string `json:"notes"`
}
type NullableTimeField struct {
Set bool
Value *time.Time
}
func (f *NullableTimeField) UnmarshalJSON(data []byte) error {
f.Set = true
if bytes.Equal(data, []byte("null")) {
f.Value = nil
return nil
}
var value time.Time
if err := json.Unmarshal(data, &value); err != nil {
return err
}
f.Value = &value
return nil
}
type NullableInt64Field struct {
Set bool
Value *int64
}
func (f *NullableInt64Field) UnmarshalJSON(data []byte) error {
f.Set = true
if bytes.Equal(data, []byte("null")) {
f.Value = nil
return nil
}
var value int64
if err := json.Unmarshal(data, &value); err != nil {
return err
}
f.Value = &value
return nil
}
type BatchUpdateRedeemCodeFields struct {
Status *string `json:"status,omitempty"`
ExpiresAt NullableTimeField `json:"expires_at,omitempty"`
Notes *string `json:"notes,omitempty"`
GroupID NullableInt64Field `json:"group_id,omitempty"`
Type *string `json:"type,omitempty"`
Value *float64 `json:"value,omitempty"`
}
type BatchUpdateRedeemCodesRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
Fields BatchUpdateRedeemCodeFields `json:"fields" binding:"required"`
}
// UsageLog 是普通用户接口使用的 usage log DTO不包含管理员字段
type UsageLog struct {
ID int64 `json:"id"`

View File

@ -785,8 +785,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if channelMapping.Mapped {
parsedReq.Model = channelMapping.MappedModel
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
// Bedrock CC 兼容:渠道模型映射后,清理 Anthropic API 专有字段、注入 Bedrock 必需字段
parsedReq.Body = h.gatewayService.ApplyBedrockCCCompat(c.Request.Context(), parsedReq.Body, parsedReq.Model, account, apiKey.GroupID)
body = parsedReq.Body
// 转发请求 - 根据账号平台分流
c.Set("parsed_request", parsedReq)

View File

@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),
@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
if channelMappingMsg.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
}
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai_messages.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
result, err := func() (*service.OpenAIForwardResult, error) {
defer func() {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
}()
return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
}()
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
@ -217,6 +221,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.Error(err),
)
} else {
var imageUpstreamErr *service.OpenAIImagesUpstreamError
if errors.As(err, &imageUpstreamErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
reqLog.Warn("openai.images.upstream_user_error",
zap.Int64("account_id", account.ID),
zap.Int("status_code", imageUpstreamErr.StatusCode),
zap.String("error_type", imageUpstreamErr.ErrorType),
zap.String("error_code", imageUpstreamErr.Code),
zap.Error(err),
)
return
}
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
@ -246,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
switchCount++
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
h.handleFailoverExhausted(c, failoverErr, streamStarted)
return
}
reqLog.Warn("openai.images.upstream_failover_switching",
zap.Int64("account_id", account.ID),
zap.Int("upstream_status", failoverErr.StatusCode),

View File

@ -41,13 +41,18 @@ const (
opsErrInsufficientQuota = "insufficient_quota"
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
opsCodeInvalidAPIKey = "INVALID_API_KEY"
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
opsCodeInvalidAPIKey = "INVALID_API_KEY"
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
opsCodeAPIKeyExpired = "API_KEY_EXPIRED"
opsCodeAPIKeyDisabled = "API_KEY_DISABLED"
opsCodeUserNotFound = "USER_NOT_FOUND"
opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED"
opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated"
)
const (
@ -1091,8 +1096,7 @@ func classifyOpsPhase(errType, message, code string) string {
if isOpsClientAuthError(code, msg) {
return "auth"
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
if isOpsLocalBusinessLimitError(code, msg) {
return "request"
}
@ -1151,8 +1155,10 @@ func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status i
if routingCapacityLimited {
phase = "routing"
}
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
msg := strings.ToLower(message)
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, msg)
localBusinessLimited := !upstreamError && classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
isBusinessLimited = routingCapacityLimited || (clientBusinessLimited && !upstreamError) || localBusinessLimited
errorOwner = classifyOpsErrorOwner(phase, message)
errorSource = classifyOpsErrorSource(phase, message)
return phase, isBusinessLimited, errorOwner, errorSource
@ -1162,8 +1168,7 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
if len(localClientAuthError) > 0 && localClientAuthError[0] {
return true
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
if isOpsLocalBusinessLimitError(code, strings.ToLower(message)) {
return true
}
if phase == "billing" || phase == "concurrency" {
@ -1180,10 +1185,45 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
func isOpsClientAuthError(code string, msg string) bool {
switch strings.TrimSpace(code) {
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
case opsCodeInvalidAPIKey,
opsCodeAPIKeyRequired,
opsCodeAPIKeyExpired,
opsCodeAPIKeyDisabled,
opsCodeUserNotFound,
opsCodeUserInactive:
return true
}
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
return strings.Contains(msg, "invalid api key") ||
strings.Contains(msg, "api key is required") ||
strings.Contains(msg, "api key is disabled") ||
strings.Contains(msg, "user associated with api key not found") ||
strings.Contains(msg, "user account is not active")
}
func isOpsLocalBusinessLimitError(code string, msg string) bool {
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance,
opsCodeUsageLimitExceeded,
opsCodeSubscriptionNotFound,
opsCodeSubscriptionInvalid,
opsCodeAPIKeyQuotaExhausted,
opsCodeAPIKeyQueryDeprecated:
return true
}
return strings.Contains(msg, "api key in query parameter is deprecated") ||
strings.Contains(msg, "query parameter api_key is deprecated") ||
strings.Contains(msg, "no active subscription found for this group") ||
strings.Contains(msg, opsErrInsufficientBalance) ||
strings.Contains(msg, "insufficient account balance") ||
strings.Contains(msg, "api key group platform is not gemini") ||
strings.Contains(msg, "api key 额度已用完") ||
strings.Contains(msg, "api key 5小时限额已用完") ||
strings.Contains(msg, "api key 日限额已用完") ||
strings.Contains(msg, "api key 7天限额已用完") ||
strings.Contains(msg, "daily usage limit exceeded") ||
strings.Contains(msg, "weekly usage limit exceeded") ||
strings.Contains(msg, "monthly usage limit exceeded") ||
strings.Contains(msg, "requests-per-minute limit exceeded")
}
func hasOpsUpstreamErrorContext(c *gin.Context) bool {

View File

@ -260,6 +260,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
code: "API_KEY_REQUIRED",
status: http.StatusUnauthorized,
},
{
name: "expired local API key",
errType: "api_error",
message: "API key 已过期",
code: "API_KEY_EXPIRED",
status: http.StatusForbidden,
},
{
name: "disabled local API key",
errType: "api_error",
message: "API key is disabled",
code: "API_KEY_DISABLED",
status: http.StatusUnauthorized,
},
{
name: "local API key user missing",
errType: "api_error",
message: "User associated with API key not found",
code: "USER_NOT_FOUND",
status: http.StatusUnauthorized,
},
{
name: "inactive local API key user",
errType: "api_error",
message: "User account is not active",
code: "USER_INACTIVE",
status: http.StatusUnauthorized,
},
{
name: "google invalid API key",
errType: "api_error",
@ -274,6 +302,27 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
code: "401",
status: http.StatusUnauthorized,
},
{
name: "google disabled API key",
errType: "api_error",
message: "API key is disabled",
code: "401",
status: http.StatusUnauthorized,
},
{
name: "google local API key user missing",
errType: "api_error",
message: "User associated with API key not found",
code: "401",
status: http.StatusUnauthorized,
},
{
name: "google inactive local API key user",
errType: "api_error",
message: "User account is not active",
code: "401",
status: http.StatusUnauthorized,
},
}
for _, tt := range tests {
@ -294,6 +343,126 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
}
}
func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
tests := []struct {
name string
errType string
message string
code string
status int
wantErrType string
wantPhase string
}{
{
name: "standard API key quota exhausted",
errType: "api_error",
message: "API key 额度已用完",
code: "API_KEY_QUOTA_EXHAUSTED",
status: http.StatusTooManyRequests,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "standard query API key deprecated",
errType: "api_error",
message: "API key in query parameter is deprecated. Please use Authorization header instead.",
code: "api_key_in_query_deprecated",
status: http.StatusBadRequest,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "google query API key deprecated",
errType: "api_error",
message: "Query parameter api_key is deprecated. Use Authorization header or key instead.",
code: "400",
status: http.StatusBadRequest,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "google no active subscription",
errType: "api_error",
message: "No active subscription found for this group",
code: "403",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "google insufficient account balance",
errType: "api_error",
message: "Insufficient account balance",
code: "403",
status: http.StatusForbidden,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "gateway billing cache insufficient balance",
errType: "billing_error",
message: "insufficient balance",
code: "",
status: http.StatusForbidden,
wantErrType: "billing_error",
wantPhase: "request",
},
{
name: "gemini group platform mismatch",
errType: "api_error",
message: "API key group platform is not gemini",
code: "400",
status: http.StatusBadRequest,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "gateway API key 5h rate limit",
errType: "api_error",
message: "api key 5小时限额已用完",
code: "rate_limit_exceeded",
status: http.StatusTooManyRequests,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "gateway group RPM limit",
errType: "api_error",
message: "group requests-per-minute limit exceeded",
code: "rate_limit_exceeded",
status: http.StatusTooManyRequests,
wantErrType: "api_error",
wantPhase: "request",
},
{
name: "google subscription daily limit",
errType: "api_error",
message: "daily usage limit exceeded",
code: "429",
status: http.StatusTooManyRequests,
wantErrType: "api_error",
wantPhase: "request",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
errType := normalizeOpsErrorType(tt.errType, tt.code)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
require.Equal(t, tt.wantErrType, errType)
require.Equal(t, tt.wantPhase, phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
})
}
}
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
@ -372,23 +541,71 @@ func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
}
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
tests := []struct {
name string
message string
code string
status int
}{
{
name: "invalid API key",
message: "Invalid API key",
code: "401",
status: http.StatusUnauthorized,
},
{
name: "disabled API key",
message: "API key is disabled",
code: "API_KEY_DISABLED",
status: http.StatusUnauthorized,
},
{
name: "gemini group platform mismatch",
message: "API key group platform is not gemini",
code: "400",
status: http.StatusBadRequest,
},
{
name: "provider balance error",
message: "Insufficient account balance",
code: "INSUFFICIENT_BALANCE",
status: http.StatusForbidden,
},
{
name: "provider subscription error",
message: "No active subscription found for this group",
code: "SUBSCRIPTION_NOT_FOUND",
status: http.StatusForbidden,
},
{
name: "provider quota error",
message: "api key 额度已用完",
code: "API_KEY_QUOTA_EXHAUSTED",
status: http.StatusTooManyRequests,
},
}
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"Invalid API key",
"401",
http.StatusUnauthorized,
)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, tt.status, tt.message, "")
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
tt.message,
tt.code,
tt.status,
)
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
})
}
}
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {

View File

@ -89,7 +89,7 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
return nil, fmt.Errorf("parse responses input item: %w", err)
}
role := rawString(item["role"])
role := chatCompletionsBridgeRole(rawString(item["role"]))
itemType := rawString(item["type"])
switch itemType {
case "function_call":
@ -130,9 +130,6 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
continue
}
if role == "" {
role = "user"
}
content := item["content"]
if len(bytesTrimSpace(content)) == 0 {
if text := rawString(item["text"]); text != "" {
@ -152,6 +149,17 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
return messages, nil
}
func chatCompletionsBridgeRole(role string) string {
trimmed := strings.TrimSpace(role)
if trimmed == "" {
return "user"
}
if strings.EqualFold(trimmed, "developer") {
return "system"
}
return role
}
func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) {
raw = bytesTrimSpace(raw)
if len(raw) == 0 || string(raw) == "null" {

View File

@ -0,0 +1,82 @@
package apicompat
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResponsesInputToChatMessages_DeveloperRoleMapsToSystem(t *testing.T) {
messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"developer","content":"follow project instructions"}]`))
require.NoError(t, err)
require.Len(t, messages, 1)
assert.Equal(t, "system", messages[0].Role)
assert.JSONEq(t, `"follow project instructions"`, string(messages[0].Content))
}
func TestResponsesInputToChatMessages_KeepsChatCompletionRoles(t *testing.T) {
input := json.RawMessage(`[
{"role":"system","content":"system message"},
{"role":"user","content":"user message"},
{"role":"assistant","content":"assistant message"},
{"role":"tool","content":"tool message"}
]`)
messages, err := responsesInputToChatMessages("", input)
require.NoError(t, err)
require.Len(t, messages, 4)
assert.Equal(t, []string{"system", "user", "assistant", "tool"}, chatMessageRoles(messages))
}
func TestResponsesInputToChatMessages_EmptyRoleFallsBackToUser(t *testing.T) {
messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"","content":"hello"}]`))
require.NoError(t, err)
require.Len(t, messages, 1)
assert.Equal(t, "user", messages[0].Role)
}
func TestResponsesInputToChatMessages_DeveloperRoleTrimAndCaseInsensitive(t *testing.T) {
input := json.RawMessage(`[
{"role":" Developer ","content":"one"},
{"role":"\tDEVELOPER\n","content":"two"}
]`)
messages, err := responsesInputToChatMessages("", input)
require.NoError(t, err)
require.Len(t, messages, 2)
assert.Equal(t, []string{"system", "system"}, chatMessageRoles(messages))
}
func TestResponsesToChatCompletionsRequest_InstructionsAndInputDeveloperRole(t *testing.T) {
req := &ResponsesRequest{
Model: "gpt-4o",
Instructions: "Use concise answers.",
Input: json.RawMessage(`[
{"role":"developer","content":[{"type":"input_text","text":"Prefer JSON."}]},
{"role":"user","content":"Hello"}
]`),
}
out, err := ResponsesToChatCompletionsRequest(req)
require.NoError(t, err)
require.Len(t, out.Messages, 3)
assert.Equal(t, []string{"system", "system", "user"}, chatMessageRoles(out.Messages))
assert.JSONEq(t, `"Use concise answers."`, string(out.Messages[0].Content))
assert.JSONEq(t, `"Prefer JSON."`, string(out.Messages[1].Content))
assert.JSONEq(t, `"Hello"`, string(out.Messages[2].Content))
}
func chatMessageRoles(messages []ChatMessage) []string {
roles := make([]string, 0, len(messages))
for _, message := range messages {
roles = append(roles, message.Role)
}
return roles
}

View File

@ -325,6 +325,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
if account == nil {
return nil
}
schedulable := account.Schedulable
if account.Status == service.StatusError {
schedulable = false
}
builder := r.client.Account.UpdateOneID(account.ID).
SetName(account.Name).
@ -337,7 +341,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetPriority(account.Priority).
SetStatus(account.Status).
SetErrorMessage(account.ErrorMessage).
SetSchedulable(account.Schedulable).
SetSchedulable(schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
@ -458,6 +462,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
return err
}
}
r.deleteSchedulerAccountSnapshot(ctx, id)
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
}
@ -724,6 +729,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
Where(dbaccount.IDEQ(id)).
SetStatus(service.StatusError).
SetErrorMessage(errorMsg).
SetSchedulable(false).
Save(ctx)
if err != nil {
return err
@ -757,6 +763,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
func (r *accountRepository) deleteSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
if r == nil || r.schedulerCache == nil || accountID <= 0 {
return
}
if err := r.schedulerCache.DeleteAccount(ctx, accountID); err != nil {
logger.LegacyPrintf("repository.account", "[Scheduler] delete account snapshot failed: id=%d err=%v", accountID, err)
}
}
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
return

View File

@ -23,6 +23,7 @@ type AccountRepoSuite struct {
type schedulerCacheRecorder struct {
setAccounts []*service.Account
deleteIDs []int64
accounts map[int64]*service.Account
}
@ -53,6 +54,10 @@ func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *servic
}
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
s.deleteIDs = append(s.deleteIDs, accountID)
if s.accounts != nil {
delete(s.accounts, accountID)
}
return nil
}
@ -185,6 +190,27 @@ func (s *AccountRepoSuite) TestDelete() {
s.Require().Error(err, "expected error after delete")
}
func (s *AccountRepoSuite) TestDelete_RemovesSchedulerAccountSnapshot() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete-cache"})
cacheRecorder := &schedulerCacheRecorder{
accounts: map[int64]*service.Account{
account.ID: {
ID: account.ID,
Name: account.Name,
Status: service.StatusActive,
Schedulable: true,
},
},
}
s.repo.schedulerCache = cacheRecorder
err := s.repo.Delete(s.ctx, account.ID)
s.Require().NoError(err, "Delete")
s.Require().Equal([]int64{account.ID}, cacheRecorder.deleteIDs)
s.Require().NotContains(cacheRecorder.accounts, account.ID)
}
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
@ -729,7 +755,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
// --- SetError ---
func (s *AccountRepoSuite) TestSetError() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive, Schedulable: true})
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
@ -737,6 +763,22 @@ func (s *AccountRepoSuite) TestSetError() {
s.Require().NoError(err)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("something went wrong", got.ErrorMessage)
s.Require().False(got.Schedulable)
}
func (s *AccountRepoSuite) TestUpdateErrorStatusUnschedulesAccount() {
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-update-err", Status: service.StatusActive, Schedulable: true})
account.Status = service.StatusError
account.ErrorMessage = "token revoked"
account.Schedulable = true
s.Require().NoError(s.repo.Update(s.ctx, account))
got, err := s.repo.GetByID(s.ctx, account.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusError, got.Status)
s.Require().Equal("token revoked", got.ErrorMessage)
s.Require().False(got.Schedulable)
}
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {

View File

@ -236,6 +236,91 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
return nil
}
func (r *redeemCodeRepository) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
uniqueIDs := make([]int64, 0, len(ids))
seen := make(map[int64]struct{}, len(ids))
for _, id := range ids {
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
uniqueIDs = append(uniqueIDs, id)
}
if len(uniqueIDs) == 0 {
return 0, nil
}
if tx := dbent.TxFromContext(ctx); tx != nil {
return r.batchUpdate(ctx, tx.Client(), uniqueIDs, fields)
}
tx, err := r.client.Tx(ctx)
if err != nil {
return 0, err
}
txCtx := dbent.NewTxContext(ctx, tx)
defer func() { _ = tx.Rollback() }()
updated, err := r.batchUpdate(txCtx, tx.Client(), uniqueIDs, fields)
if err != nil {
return 0, err
}
if err := tx.Commit(); err != nil {
return 0, err
}
return updated, nil
}
func (r *redeemCodeRepository) batchUpdate(ctx context.Context, client *dbent.Client, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
existing, err := client.RedeemCode.Query().
Where(redeemcode.IDIn(ids...)).
All(ctx)
if err != nil {
return 0, err
}
if len(existing) != len(ids) {
return 0, service.ErrRedeemCodeNotFound
}
if fields.TouchesUsedSensitiveFields() {
for _, code := range existing {
if code.Status == service.StatusUsed {
return 0, service.ErrRedeemCodeUsed
}
}
}
up := client.RedeemCode.Update().Where(redeemcode.IDIn(ids...))
if fields.Status != nil {
up.SetStatus(*fields.Status)
}
if fields.Notes != nil {
up.SetNotes(*fields.Notes)
}
if fields.ExpiresAt.Set {
if fields.ExpiresAt.Value != nil {
up.SetExpiresAt(*fields.ExpiresAt.Value)
} else {
up.ClearExpiresAt()
}
}
if fields.GroupID.Set {
if fields.GroupID.Value != nil {
up.SetGroupID(*fields.GroupID.Value)
} else {
up.ClearGroupID()
}
}
affected, err := up.Save(ctx)
if err != nil {
return 0, err
}
if affected != len(ids) {
return 0, service.ErrRedeemCodeNotFound
}
return int64(affected), nil
}
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
now := time.Now()
client := clientFromContext(ctx, r.client)

View File

@ -4,6 +4,7 @@ package repository
import (
"context"
"errors"
"testing"
"time"
@ -21,8 +22,8 @@ type RedeemCodeRepoSuite struct {
}
func (s *RedeemCodeRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.ctx = dbent.NewTxContext(context.Background(), tx)
s.client = tx.Client()
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
}
@ -237,6 +238,123 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
s.Require().Equal(float64(50), got.Value)
}
func (s *RedeemCodeRepoSuite) TestBatchUpdate_PartialFieldsAndClear() {
group := s.createGroup(uniqueTestValue(s.T(), "batch-update-group"))
groupID := group.ID
expiresAt := time.Now().UTC().Add(2 * time.Hour)
status := service.StatusDisabled
notes := "batch note"
codeA := &service.RedeemCode{
Code: "BATCH-UP-A",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
Notes: "old",
}
codeB := &service.RedeemCode{
Code: "BATCH-UP-B",
Type: service.RedeemTypeBalance,
Value: 20,
Status: service.StatusUnused,
Notes: "old",
}
untouched := &service.RedeemCode{
Code: "BATCH-UP-C",
Type: service.RedeemTypeBalance,
Value: 30,
Status: service.StatusUnused,
Notes: "keep",
}
s.Require().NoError(s.repo.Create(s.ctx, codeA))
s.Require().NoError(s.repo.Create(s.ctx, codeB))
s.Require().NoError(s.repo.Create(s.ctx, untouched))
updated, err := s.repo.BatchUpdate(s.ctx, []int64{codeA.ID, codeB.ID}, service.RedeemCodeBatchUpdateFields{
Status: &status,
ExpiresAt: service.NullableTimeUpdate{Set: true, Value: &expiresAt},
Notes: &notes,
GroupID: service.NullableInt64Update{Set: true, Value: &groupID},
})
s.Require().NoError(err)
s.Require().Equal(int64(2), updated)
gotA, err := s.repo.GetByID(s.ctx, codeA.ID)
s.Require().NoError(err)
s.Require().Equal("BATCH-UP-A", gotA.Code)
s.Require().Equal(service.RedeemTypeBalance, gotA.Type)
s.Require().Equal(float64(10), gotA.Value)
s.Require().Equal(service.StatusDisabled, gotA.Status)
s.Require().Equal(notes, gotA.Notes)
s.Require().NotNil(gotA.ExpiresAt)
s.Require().WithinDuration(expiresAt, *gotA.ExpiresAt, time.Second)
s.Require().NotNil(gotA.GroupID)
s.Require().Equal(groupID, *gotA.GroupID)
gotB, err := s.repo.GetByID(s.ctx, codeB.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusDisabled, gotB.Status)
s.Require().Equal(notes, gotB.Notes)
gotUntouched, err := s.repo.GetByID(s.ctx, untouched.ID)
s.Require().NoError(err)
s.Require().Equal(service.StatusUnused, gotUntouched.Status)
s.Require().Equal("keep", gotUntouched.Notes)
s.Require().Nil(gotUntouched.ExpiresAt)
s.Require().Nil(gotUntouched.GroupID)
updated, err = s.repo.BatchUpdate(s.ctx, []int64{codeA.ID}, service.RedeemCodeBatchUpdateFields{
ExpiresAt: service.NullableTimeUpdate{Set: true},
GroupID: service.NullableInt64Update{Set: true},
})
s.Require().NoError(err)
s.Require().Equal(int64(1), updated)
gotA, err = s.repo.GetByID(s.ctx, codeA.ID)
s.Require().NoError(err)
s.Require().Nil(gotA.ExpiresAt)
s.Require().Nil(gotA.GroupID)
}
func (s *RedeemCodeRepoSuite) TestBatchUpdate_InvalidIDRollsBack() {
code := &service.RedeemCode{
Code: "BATCH-UP-ROLLBACK",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUnused,
Notes: "keep",
}
s.Require().NoError(s.repo.Create(s.ctx, code))
notes := "changed"
_, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID, 999999}, service.RedeemCodeBatchUpdateFields{Notes: &notes})
s.Require().Error(err)
s.Require().True(errors.Is(err, service.ErrRedeemCodeNotFound))
got, getErr := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(getErr)
s.Require().Equal("keep", got.Notes)
}
func (s *RedeemCodeRepoSuite) TestBatchUpdate_UsedCodeRejectsSensitiveFields() {
code := &service.RedeemCode{
Code: "BATCH-UP-USED",
Type: service.RedeemTypeBalance,
Value: 10,
Status: service.StatusUsed,
}
s.Require().NoError(s.repo.Create(s.ctx, code))
status := service.StatusDisabled
_, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID}, service.RedeemCodeBatchUpdateFields{Status: &status})
s.Require().Error(err)
s.Require().True(errors.Is(err, service.ErrRedeemCodeUsed))
got, getErr := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(getErr)
s.Require().Equal(service.StatusUsed, got.Status)
}
// --- Use ---
func (s *RedeemCodeRepoSuite) TestUse() {

View File

@ -757,6 +757,7 @@ func TestAPIContracts(t *testing.T) {
"site_logo": "",
"site_subtitle": "Subtitle",
"api_base_url": "https://api.example.com",
"api_key_acl_trust_forwarded_ip": false,
"contact_info": "support",
"doc_url": "https://docs.example.com",
"auth_source_default_email_balance": 0,
@ -862,6 +863,7 @@ func TestAPIContracts(t *testing.T) {
"payment_alipay_force_qrcode": false,
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"subscription_expiry_notify_enabled": true,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
@ -1014,6 +1016,7 @@ func TestAPIContracts(t *testing.T) {
"site_logo": "",
"site_subtitle": "Subscription to API Conversion Platform",
"api_base_url": "",
"api_key_acl_trust_forwarded_ip": false,
"contact_info": "",
"doc_url": "",
"home_content": "",
@ -1086,6 +1089,7 @@ func TestAPIContracts(t *testing.T) {
"payment_alipay_force_qrcode": false,
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"subscription_expiry_notify_enabled": true,
"balance_low_notify_threshold": 0,
"balance_low_notify_recharge_url": "",
"account_quota_notify_emails": [],
@ -1254,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
@ -1852,6 +1856,10 @@ func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode)
return errors.New("not implemented")
}
func (stubRedeemCodeRepo) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
return int64(len(ids)), nil
}
func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error {
return errors.New("not implemented")
}

View File

@ -18,7 +18,6 @@ import (
"github.com/google/wire"
"github.com/redis/go-redis/v9"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// ProviderSet 提供服务器层的依赖
@ -102,6 +101,16 @@ func ProvideRouter(
// ProvideHTTPServer 提供 HTTP 服务器
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
httpHandler := http.Handler(router)
server := &http.Server{
Addr: cfg.Server.Address(),
Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
globalMaxSize := cfg.Server.MaxRequestBodySize
if globalMaxSize <= 0 {
@ -115,32 +124,31 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
// 根据配置决定是否启用 H2C
if cfg.Server.H2C.Enabled {
h2cConfig := cfg.Server.H2C
httpHandler = h2c.NewHandler(router, &http2.Server{
if err := http2.ConfigureServer(server, &http2.Server{
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
})
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}); err != nil {
log.Printf("Failed to configure HTTP/2 Cleartext (h2c): %v", err)
} else {
protocols := new(http.Protocols)
protocols.SetHTTP1(true)
protocols.SetUnencryptedHTTP2(true)
server.Protocols = protocols
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
h2cConfig.MaxConcurrentStreams,
h2cConfig.IdleTimeout,
h2cConfig.MaxReadFrameSize,
h2cConfig.MaxUploadBufferPerConnection,
h2cConfig.MaxUploadBufferPerStream,
)
}
}
return &http.Server{
Addr: cfg.Server.Address(),
Handler: httpHandler,
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
// 注意:不设置 WriteTimeout因为流式响应可能持续十几分钟
// 不设置 ReadTimeout因为大请求体可能需要较长时间读取
}
server.Handler = httpHandler
return server
}
func derefInt64(p *int64) int64 {

View File

@ -90,6 +90,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
clientIP := ip.GetTrustedClientIP(c)
if cfg.TrustForwardedIPForAPIKeyACL() {
clientIP = ip.GetClientIP(c)
}
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)

View File

@ -398,7 +398,7 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) {
}
}
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
func TestAPIKeyAuthIPRestrictionDoesNotTrustForwardedClientIPByDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
@ -460,6 +460,57 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason)
}
func TestAPIKeyAuthIPRestrictionCanTrustForwardedClientIPForReverseProxy(t *testing.T) {
gin.SetMode(gin.TestMode)
user := &service.User{
ID: 7,
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
Concurrency: 3,
}
apiKey := &service.APIKey{
ID: 100,
UserID: user.ID,
Key: "test-key",
Status: service.StatusActive,
User: user,
IPWhitelist: []string{"1.2.3.4"},
}
apiKeyRepo := &stubApiKeyRepo{
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
if key != apiKey.Key {
return nil, service.ErrAPIKeyNotFound
}
clone := *apiKey
return &clone, nil
},
}
cfg := &config.Config{RunMode: config.RunModeSimple}
cfg.SetTrustForwardedIPForAPIKeyACL(true)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
require.NoError(t, router.SetTrustedProxies(nil))
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/t", nil)
req.RemoteAddr = "9.9.9.9:12345"
req.Header.Set("x-api-key", apiKey.Key)
req.Header.Set("X-Forwarded-For", "1.2.3.4")
req.Header.Set("X-Real-IP", "1.2.3.4")
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code)
}
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
gin.SetMode(gin.TestMode)

View File

@ -4,6 +4,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@ -31,7 +32,7 @@ func Logger() gin.HandlerFunc {
method := c.Request.Method
statusCode := c.Writer.Status()
clientIP := c.ClientIP()
clientIP := ip.GetClientIP(c)
protocol := c.Request.Proto
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)

View File

@ -180,6 +180,37 @@ func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
}
}
func TestLogger_AccessLogUsesForwardedClientIP(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)
r := gin.New()
r.Use(Logger())
r.GET("/api/test", func(c *gin.Context) {
c.Status(http.StatusOK)
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
req.RemoteAddr = "104.23.251.120:443"
req.Header.Set("CF-Connecting-IP", "203.0.113.42")
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status=%d", w.Code)
}
for _, event := range sink.list() {
if event == nil || event.Message != "http request completed" {
continue
}
if got := event.Fields["client_ip"]; got != "203.0.113.42" {
t.Fatalf("client_ip=%q, want real forwarded ip", got)
}
return
}
t.Fatalf("access log event not found")
}
func TestLogger_HealthPathSkipped(t *testing.T) {
gin.SetMode(gin.TestMode)
sink := initMiddlewareTestLogger(t)

View File

@ -396,6 +396,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
codes.POST("/batch-update", h.Admin.Redeem.BatchUpdate)
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
}
}

View File

@ -511,7 +511,6 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
// testOpenAIAccountConnection tests an OpenAI account's connection
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
ctx := c.Request.Context()
_ = prompt
mode = normalizeAccountTestMode(mode)
// Default to openai.DefaultTestModel for OpenAI testing
@ -572,14 +571,8 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
}
// 账号已被探测为不支持 Responses如 DeepSeek/Kimi 等)时,丢出明确提示。
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
// TODO实现 CC 格式的账号测试路径(需专门的 CC SSE handler
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.sendErrorAndEnd(c,
"账号已被探测为不支持 OpenAI Responses API如 DeepSeek/Kimi 等三方兼容上游),"+
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
)
return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, normalizedBaseURL, authToken)
}
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
} else {
@ -654,6 +647,65 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
return s.processOpenAIStream(c, resp.Body)
}
// testOpenAIChatCompletionsConnection tests an OpenAI-compatible APIKey account
// through the raw /v1/chat/completions endpoint.
func (s *AccountTestService) testOpenAIChatCompletionsConnection(
c *gin.Context,
account *Account,
testModelID string,
prompt string,
normalizedBaseURL string,
authToken string,
) error {
ctx := c.Request.Context()
apiURL := buildOpenAIChatCompletionsURL(normalizedBaseURL)
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.Flush()
payload := createOpenAIChatCompletionsTestPayload(testModelID, prompt)
payloadBytes, _ := json.Marshal(payload)
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
s.sendEvent(c, TestEvent{Type: "status", Text: "正在通过 /v1/chat/completions 测试连接"})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(payloadBytes))
if err != nil {
return s.sendErrorAndEnd(c, "Failed to create Chat Completions request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Authorization", "Bearer "+authToken)
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
if err != nil {
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) request failed: %s", err.Error()))
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusTooManyRequests {
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
}
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
errMsg := fmt.Sprintf("Chat Completions authentication failed (401): %s", string(body))
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) returned %d: %s", resp.StatusCode, string(body)))
}
return s.processOpenAIChatCompletionsStream(c, resp.Body)
}
// testOpenAICompactConnection probes /responses/compact and persists the
// resulting capability state on the account.
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
@ -1256,6 +1308,24 @@ func createOpenAITestPayload(modelID string, isOAuth bool, prompt string) map[st
return payload
}
func createOpenAIChatCompletionsTestPayload(modelID string, prompt string) map[string]any {
testPrompt := strings.TrimSpace(prompt)
if testPrompt == "" {
testPrompt = "hi"
}
return map[string]any{
"model": modelID,
"messages": []map[string]any{
{
"role": "user",
"content": testPrompt,
},
},
"stream": true,
}
}
// processClaudeStream processes the SSE stream from Claude API
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
@ -1310,6 +1380,82 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
}
}
// processOpenAIChatCompletionsStream processes SSE chunks from the
// OpenAI-compatible Chat Completions API.
func (s *AccountTestService) processOpenAIChatCompletionsStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)
seenJSON := false
seenFinish := false
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
if seenFinish {
s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
if seenJSON {
return s.sendErrorAndEnd(c, "Chat Completions stream from /v1/chat/completions ended before [DONE]")
}
return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected SSE JSON data")
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions stream read error from /v1/chat/completions: %s", err.Error()))
}
line = strings.TrimSpace(line)
if line == "" || !sseDataPrefix.MatchString(line) {
continue
}
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
if jsonStr == "[DONE]" {
s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"})
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
return nil
}
var data map[string]any
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected JSON data")
}
seenJSON = true
if errData, ok := data["error"].(map[string]any); ok {
errorMsg := "Chat Completions API (/v1/chat/completions) returned an error"
if msg, ok := errData["message"].(string); ok && msg != "" {
errorMsg = msg
}
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) error: %s", errorMsg))
}
choices, ok := data["choices"].([]any)
if !ok {
continue
}
for _, choiceValue := range choices {
choice, ok := choiceValue.(map[string]any)
if !ok {
continue
}
if delta, ok := choice["delta"].(map[string]any); ok {
if text, ok := delta["content"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
}
if message, ok := choice["message"].(map[string]any); ok {
if text, ok := message["content"].(string); ok && text != "" {
s.sendEvent(c, TestEvent{Type: "content", Text: text})
}
}
if finishReason, ok := choice["finish_reason"].(string); ok && finishReason != "" {
seenFinish = true
}
}
}
}
// processOpenAIStream processes the SSE stream from OpenAI Responses API
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
reader := bufio.NewReader(body)

View File

@ -12,10 +12,12 @@ import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/tidwall/gjson"
)
// --- shared test helpers ---
@ -333,3 +335,145 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
require.Zero(t, repo.clearedErrorID)
require.Nil(t, account.RateLimitResetAt)
}
func TestAccountTestService_OpenAIAPIKeyResponsesUnsupportedUsesChatCompletionsPath(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
upstreamBody := strings.Join([]string{
`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"pong"},"finish_reason":null}]}`,
"",
`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
"",
"data: [DONE]",
"",
}, "\n")
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader(upstreamBody)),
}}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
account := &Account{
ID: 91,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://compat-upstream.example/v1",
},
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "hello", "")
require.NoError(t, err)
require.NotNil(t, upstream.lastReq)
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization"))
require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept"))
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool())
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String())
require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
body := recorder.Body.String()
require.Contains(t, body, "pong")
require.Contains(t, body, "已通过 /v1/chat/completions 验证")
require.Contains(t, body, `"success":true`)
require.NotContains(t, body, "当前测试接口仅支持 Responses API 路径")
}
func TestAccountTestService_OpenAIChatCompletionsPathReturns4xx(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
upstream := &httpUpstreamRecorder{resp: newJSONResponse(http.StatusBadRequest, `{"error":{"message":"bad request"}}`)}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
account := &Account{
ID: 92,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://compat-upstream.example",
},
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) returned 400")
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}
func TestAccountTestService_OpenAIChatCompletionsPathTimeout(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
upstream := &httpUpstreamRecorder{err: context.DeadlineExceeded}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
account := &Account{
ID: 93,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://compat-upstream.example",
},
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) request failed")
require.Contains(t, err.Error(), context.DeadlineExceeded.Error())
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}
func TestAccountTestService_OpenAIChatCompletionsPathRejectsNonJSONStream(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newTestContext()
upstream := &httpUpstreamRecorder{resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
Body: io.NopCloser(strings.NewReader("data: not-json\n\n")),
}}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
}
account := &Account{
ID: 94,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "sk-test",
"base_url": "https://compat-upstream.example",
},
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
require.Error(t, err)
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
require.Contains(t, err.Error(), "Invalid Chat Completions response from /v1/chat/completions")
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
require.NotContains(t, recorder.Body.String(), `"success":true`)
}

View File

@ -531,6 +531,7 @@ type adminServiceImpl struct {
defaultSubAssigner DefaultSubscriptionAssigner
userSubRepo UserSubscriptionRepository
privacyClientFactory PrivacyClientFactory
runtimeBlocker AccountRuntimeBlocker
}
type userGroupRateBatchReader interface {
@ -556,6 +557,7 @@ func NewAdminService(
defaultSubAssigner DefaultSubscriptionAssigner,
userSubRepo UserSubscriptionRepository,
privacyClientFactory PrivacyClientFactory,
runtimeBlocker AccountRuntimeBlocker,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@ -575,6 +577,7 @@ func NewAdminService(
defaultSubAssigner: defaultSubAssigner,
userSubRepo: userSubRepo,
privacyClientFactory: privacyClientFactory,
runtimeBlocker: runtimeBlocker,
}
}
@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
return nil, err
}
if s.runtimeBlocker != nil {
s.runtimeBlocker.ClearAccountSchedulingBlock(id)
}
return s.accountRepo.GetByID(ctx, id)
}

View File

@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
TempUnschedulableReason: "missing refresh token",
},
}
svc := &adminServiceImpl{accountRepo: repo}
blocker := &runtimeBlockRecorder{}
svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker}
updated, err := svc.ClearAccountError(context.Background(), 31)
require.NoError(t, err)
@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
require.Nil(t, updated.RateLimitResetAt)
require.Nil(t, updated.TempUnschedulableUntil)
require.Empty(t, updated.TempUnschedulableReason)
require.Equal(t, []int64{31}, blocker.clearedIDs)
}

View File

@ -325,6 +325,12 @@ func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxy
type redeemRepoStub struct {
deleteErrByID map[int64]error
deletedIDs []int64
batchUpdateIDs []int64
batchUpdateFields RedeemCodeBatchUpdateFields
batchUpdateResult int64
batchUpdateErr error
batchUpdateCalled bool
}
func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
@ -347,6 +353,19 @@ func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error {
panic("unexpected Update call")
}
func (s *redeemRepoStub) BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error) {
s.batchUpdateCalled = true
s.batchUpdateIDs = append([]int64(nil), ids...)
s.batchUpdateFields = fields
if s.batchUpdateErr != nil {
return 0, s.batchUpdateErr
}
if s.batchUpdateResult != 0 {
return s.batchUpdateResult, nil
}
return int64(len(ids)), nil
}
func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
if s.deleteErrByID != nil {

View File

@ -49,7 +49,7 @@ func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
}
providerType := normalizeOAuthSignupSource(input.ProviderType)
if providerType != "github" && providerType != "google" {
if providerType != "github" && providerType != "google" && providerType != "oidc" {
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
}
providerKey := strings.TrimSpace(input.ProviderKey)

View File

@ -59,6 +59,10 @@ func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
return nil
}
func (s *redeemCodeRepoStub) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) {
panic("unexpected BatchUpdate call")
}
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}

View File

@ -9,12 +9,16 @@ import (
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
// featureKeyBedrockCCCompat is the key used in Channel.FeaturesConfig for Bedrock CC compatibility.
const featureKeyBedrockCCCompat = "bedrock_cc_compat"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
@ -179,13 +183,16 @@ func BuildBedrockURL(region, modelID string, stream bool) string {
// 3. 移除 Bedrock 不支持的字段model, stream, output_format, output_config
// 4. 移除工具定义中的 custom 字段Claude Code 会发送 custom: {defer_loading: true}
// 5. 清理 cache_control 中 Bedrock 不支持的字段scope, ttl
// 6. 修复 thinking 字段兼容性Opus 4.7 仅支持 adaptiveenabled 需要 budget_tokens
// 7. 清理 tool_use.id / tool_use_id 中 Bedrock 不接受的字符
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens, false)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
// ccCompat 启用 CC 兼容模式时额外处理 thinking 类型转换和 tool_use.id 清理。
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string, ccCompat bool) ([]byte, error) {
var err error
// 注入 anthropic_versionBedrock 要求)
@ -203,6 +210,7 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Injected beta tokens: %v (model=%s ccCompat=%v)", betaTokens, modelID, ccCompat)
}
// 移除 model 字段Bedrock 通过 URL 指定模型)
@ -235,6 +243,12 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens
// 清理 cache_control 中 Bedrock 不支持的字段
body = sanitizeBedrockCacheControl(body, modelID)
// CC 兼容模式:修复 CC 发送的 Bedrock 不兼容字段
if ccCompat {
body = sanitizeBedrockThinking(body, modelID)
body = sanitizeBedrockToolUseIDs(body)
}
return body, nil
}
@ -444,17 +458,17 @@ func parseAnthropicBetaHeader(header string) []string {
}
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// 参考: AWS Bedrock 官方文档 + litellm anthropic_beta_headers_config.json
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
// "context-management-2025-06-27": false, // 无官方文档支持
"compact-2026-01-12": true, // 官方支持,仅 InvokeModel APIOpus 4.6+
// "interleaved-thinking-2025-05-14": false, // 无官方文档支持
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
@ -482,11 +496,8 @@ func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) [
}
}
// 检测 thinking / interleaved thinking
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// 注意thinking 字段不再自动注入 interleaved-thinking-2025-05-14
// 因为该 beta token 未在 AWS Bedrock 官方文档中确认支持
// 检测 computer_use 工具
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
@ -605,3 +616,156 @@ func filterBedrockBetaTokens(tokens []string) []string {
return result
}
// bedrockToolUseIDRe 匹配 Bedrock 允许的 tool_use ID 字符(字母、数字、下划线、连字符)
var bedrockToolUseIDRe = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
// isBedrockOpus47OrNewer 判断 Bedrock 模型 ID 是否为 Claude Opus 4.7 或更新版本
// Opus 4.7 仅支持 thinking.type: "adaptive",不支持 "enabled"
func isBedrockOpus47OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
if !strings.Contains(lower, "opus") {
return false
}
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 7)
}
const defaultThinkingBudgetTokens = 10000
// sanitizeBedrockThinking 修复 thinking 字段的 Bedrock 兼容性问题:
// - Opus 4.7+: 仅支持 "adaptive",将 "enabled" 转换为 "adaptive" 并移除 budget_tokens
// - 其他模型: "enabled" 必须带 budget_tokens缺失时补充默认值
func sanitizeBedrockThinking(body []byte, modelID string) []byte {
thinking := gjson.GetBytes(body, "thinking")
if !thinking.Exists() || !thinking.IsObject() {
return body
}
thinkingType := thinking.Get("type").String()
if thinkingType == "" {
return body
}
if isBedrockOpus47OrNewer(modelID) {
if thinkingType == "enabled" {
body, _ = sjson.SetBytes(body, "thinking.type", "adaptive")
body, _ = sjson.DeleteBytes(body, "thinking.budget_tokens")
}
return body
}
if thinkingType == "enabled" && !thinking.Get("budget_tokens").Exists() {
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", defaultThinkingBudgetTokens)
}
return body
}
// sanitizeBedrockToolUseIDs 清理 messages 中 tool_use.id 和 tool_result.tool_use_id
// 的非法字符。Bedrock 要求 ID 匹配 '^[a-zA-Z0-9_-]+$'。
func sanitizeBedrockToolUseIDs(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
switch block.Get("type").String() {
case "tool_use":
body = sanitizeIDField(body, block.Get("id").String(), fmt.Sprintf("messages.%d.content.%d.id", mi, ci))
case "tool_result":
body = sanitizeIDField(body, block.Get("tool_use_id").String(), fmt.Sprintf("messages.%d.content.%d.tool_use_id", mi, ci))
}
}
}
return body
}
func sanitizeIDField(body []byte, id, path string) []byte {
if id == "" {
return body
}
sanitized := bedrockToolUseIDRe.ReplaceAllString(id, "_")
if sanitized != id {
body, _ = sjson.SetBytes(body, path, sanitized)
}
return body
}
const defaultCCMaxTokens = 81920
// sanitizeBedrockCCFields 处理 Claude Code 发送的 Bedrock 不兼容字段:
// - 移除 service_tierAnthropic API 专有Bedrock 不支持)
// - 移除 interface_geoAnthropic API 专有Bedrock 不支持)
// - 移除 context_managementAnthropic API 专有Bedrock 不支持CC v2.1.87+ 默认携带)
// - 注入 max_tokens 默认值 81920CC 可能省略Bedrock 要求必须提供)
// - 注入 anthropic_versionCC 通过 HTTP 头发送Bedrock 需要放在请求体中)
func sanitizeBedrockCCFields(body []byte) []byte {
if gjson.GetBytes(body, "service_tier").Exists() {
body, _ = sjson.DeleteBytes(body, "service_tier")
}
if gjson.GetBytes(body, "interface_geo").Exists() {
body, _ = sjson.DeleteBytes(body, "interface_geo")
}
if gjson.GetBytes(body, "context_management").Exists() {
body, _ = sjson.DeleteBytes(body, "context_management")
}
if !gjson.GetBytes(body, "max_tokens").Exists() {
body, _ = sjson.SetBytes(body, "max_tokens", defaultCCMaxTokens)
}
if !gjson.GetBytes(body, "anthropic_version").Exists() {
body, _ = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
}
return body
}
// sanitizeBedrockCCBetaTokens 清理请求体中的 anthropic_beta 字段,只保留 Bedrock 支持的 beta token
// CC 可能在请求体中注入了 Bedrock 不支持的 beta token如 prompt-caching 等),导致 ValidationException
func sanitizeBedrockCCBetaTokens(body []byte, modelID string) []byte {
betaField := gjson.GetBytes(body, "anthropic_beta")
if !betaField.Exists() {
return body
}
var tokens []string
if betaField.IsArray() {
for _, t := range betaField.Array() {
if t.Type == gjson.String {
tokens = append(tokens, t.String())
}
}
}
originalTokens := append([]string(nil), tokens...) // 保存原始 tokens 用于日志
// 复用现有的 Bedrock beta token 过滤逻辑(自动注入 + 白名单过滤 + 转换)
// 即使 tokens 为空,也要执行自动注入(根据 body 内容补充必要的 beta token
tokens = autoInjectBedrockBetaTokens(tokens, body, modelID)
tokens = filterBedrockBetaTokens(tokens)
if len(tokens) == 0 {
// 所有 token 都被过滤掉,删除 anthropic_beta 字段
body, _ = sjson.DeleteBytes(body, "anthropic_beta")
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Removed all beta tokens: original=%v", originalTokens)
} else {
// 更新为过滤后的 token 列表
body, _ = sjson.SetBytes(body, "anthropic_beta", tokens)
if len(originalTokens) > 0 {
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Filtered beta tokens: original=%v final=%v", originalTokens, tokens)
} else {
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Auto-injected beta tokens: %v", tokens)
}
}
return body
}

View File

@ -216,7 +216,7 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
betaHeader := "context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
@ -228,10 +228,9 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// anthropic_beta 应包含所有 beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
require.Len(t, betaArr, 2)
assert.Equal(t, "context-1m-2025-08-07", betaArr[0].String())
assert.Equal(t, "compact-2026-01-12", betaArr[1].String())
// output_format 应被移除schema 内联到最后一条 user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
@ -264,29 +263,29 @@ func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07 , compact-2026-01-12 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
assert.Equal(t, "compact-2026-01-12", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["context-1m-2025-08-07","compact-2026-01-12"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
assert.Equal(t, "compact-2026-01-12", arr[1].String())
})
}
@ -301,15 +300,15 @@ func TestParseAnthropicBetaHeader(t *testing.T) {
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
tokens := []string{"context-1m-2025-08-07", "compact-2026-01-12", "computer-use-2025-11-24"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
tokens := []string{"context-1m-2025-08-07", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
@ -361,11 +360,11 @@ func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
"compact-2026-01-12, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "compact-2026-01-12", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
@ -498,22 +497,17 @@ func TestResolveBedrockModelID(t *testing.T) {
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
t.Run("no auto-inject for thinking (interleaved-thinking not supported)", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
// interleaved-thinking-2025-05-14 已从白名单移除,不应自动注入
assert.Empty(t, result)
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, t := range result {
if t == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
result := autoInjectBedrockBetaTokens([]string{"context-1m-2025-08-07"}, body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
@ -574,7 +568,8 @@ func TestAutoInjectBedrockBetaTokens(t *testing.T) {
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
// interleaved-thinking 不再自动注入
assert.NotContains(t, result, "interleaved-thinking-2025-05-14")
})
}
@ -588,27 +583,21 @@ func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
result := ResolveBedrockBetaTokens("context-1m-2025-08-07,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
t.Run("thinking in body does not auto-inject beta (not supported)", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
// interleaved-thinking 已从白名单移除,不应自动注入
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
t.Run("header tokens preserved without auto-injection", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
@ -618,7 +607,8 @@ func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
// interleaved-thinking 不再自动注入
assert.NotContains(t, names, "interleaved-thinking-2025-05-14")
})
}
@ -657,3 +647,363 @@ func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
})
}
}
func TestIsBedrockOpus47OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-7-v1", true},
{"us.anthropic.claude-opus-4-6-v1", false},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", false},
{"us.anthropic.claude-opus-5-0-v1", true},
// Sonnet 4.7 is not Opus → false
{"us.anthropic.claude-sonnet-4-7-v1", false},
{"us.anthropic.claude-sonnet-4-6", false},
// Haiku is not Opus
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", false},
// Non-Claude models
{"amazon.nova-pro-v1", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID))
})
}
}
func TestSanitizeBedrockThinking(t *testing.T) {
t.Run("opus 4.7 converts enabled to adaptive", func(t *testing.T) {
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
})
t.Run("opus 4.7 keeps adaptive unchanged", func(t *testing.T) {
input := `{"thinking":{"type":"adaptive"},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
})
t.Run("opus 4.6 enabled without budget_tokens gets default", func(t *testing.T) {
input := `{"thinking":{"type":"enabled"},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
})
t.Run("opus 4.6 enabled with budget_tokens unchanged", func(t *testing.T) {
input := `{"thinking":{"type":"enabled","budget_tokens":20000},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(20000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
})
t.Run("no thinking field unchanged", func(t *testing.T) {
input := `{"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.JSONEq(t, input, string(result))
})
t.Run("sonnet 4.6 enabled without budget_tokens gets default", func(t *testing.T) {
input := `{"thinking":{"type":"enabled"},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-sonnet-4-6")
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
})
}
func TestSanitizeBedrockToolUseIDs(t *testing.T) {
t.Run("clean IDs unchanged", func(t *testing.T) {
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu_01AbCdEf","name":"bash","input":{}}]}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.Equal(t, "toolu_01AbCdEf", gjson.GetBytes(result, "messages.0.content.0.id").String())
})
t.Run("dots in tool_use ID replaced with underscores", func(t *testing.T) {
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
})
t.Run("special chars in tool_use ID sanitized", func(t *testing.T) {
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu:01@Ab#Cd","name":"bash","input":{}}]}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
id := gjson.GetBytes(result, "messages.0.content.0.id").String()
assert.Regexp(t, `^[a-zA-Z0-9_-]+$`, id)
})
t.Run("tool_result tool_use_id sanitized", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.tool_use_id").String())
})
t.Run("mixed clean and dirty IDs", func(t *testing.T) {
input := `{"messages":[
{"role":"assistant","content":[{"type":"tool_use","id":"clean_id-123","name":"a","input":{}}]},
{"role":"user","content":[{"type":"tool_result","tool_use_id":"dirty.id@456","content":"ok"}]},
{"role":"assistant","content":[{"type":"tool_use","id":"also.dirty","name":"b","input":{}}]}
]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.Equal(t, "clean_id-123", gjson.GetBytes(result, "messages.0.content.0.id").String())
assert.Equal(t, "dirty_id_456", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String())
assert.Equal(t, "also_dirty", gjson.GetBytes(result, "messages.2.content.0.id").String())
})
t.Run("no messages unchanged", func(t *testing.T) {
input := `{"system":[{"type":"text","text":"hi"}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("string content skipped", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"plain text"}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("empty ID skipped", func(t *testing.T) {
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"","name":"a","input":{}}]}]}`
result := sanitizeBedrockToolUseIDs([]byte(input))
assert.Equal(t, "", gjson.GetBytes(result, "messages.0.content.0.id").String())
})
}
func TestSanitizeBedrockThinking_EdgeCases(t *testing.T) {
t.Run("opus 4.7 enabled without budget_tokens converts to adaptive", func(t *testing.T) {
input := `{"thinking":{"type":"enabled"},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
})
t.Run("thinking type disabled unchanged", func(t *testing.T) {
input := `{"thinking":{"type":"disabled"},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.Equal(t, "disabled", gjson.GetBytes(result, "thinking.type").String())
})
t.Run("thinking type empty string unchanged", func(t *testing.T) {
input := `{"thinking":{"type":""},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.JSONEq(t, input, string(result))
})
t.Run("thinking is not an object unchanged", func(t *testing.T) {
input := `{"thinking":true,"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.JSONEq(t, input, string(result))
})
t.Run("opus 4.7 adaptive with budget_tokens preserved", func(t *testing.T) {
input := `{"thinking":{"type":"adaptive","budget_tokens":5000},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(5000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
})
// Forward() passes parsed.Model (standard names like "claude-opus-4-7")
t.Run("standard model name opus 4.7 converts enabled to adaptive", func(t *testing.T) {
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-7")
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
})
t.Run("standard model name opus 4.6 keeps enabled", func(t *testing.T) {
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-6")
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(10000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
})
}
func TestIsBedrockOpus47OrNewer_EdgeCases(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"anthropic.claude-opus-4-7-v1", true},
{"us.anthropic.claude-opus-4-7-20270101-v1:0", true},
{"", false},
// Forward() passes parsed.Model (standard names), not Bedrock IDs
{"claude-opus-4-7", true},
{"claude-opus-4-6", false},
{"claude-sonnet-4-7", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBodyWithTokens_CCCompat(t *testing.T) {
input := `{
"model":"claude-opus-4-6",
"stream":true,
"max_tokens":16384,
"thinking":{"type":"enabled"},
"messages":[
{"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]},
{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]}
]
}`
t.Run("ccCompat=false skips thinking and toolUseID sanitization", func(t *testing.T) {
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, false)
require.NoError(t, err)
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
assert.Equal(t, "toolu.01.Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
})
t.Run("ccCompat=true applies thinking fix and toolUseID sanitization (opus 4.6)", func(t *testing.T) {
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, true)
require.NoError(t, err)
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String())
})
t.Run("ccCompat=true converts thinking to adaptive for opus 4.7", func(t *testing.T) {
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-7-v1", nil, true)
require.NoError(t, err)
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
})
}
func TestSanitizeBedrockCCFields(t *testing.T) {
t.Run("removes service_tier and interface_geo", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","service_tier":"standard","interface_geo":"us","messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
assert.True(t, gjson.GetBytes(result, "messages").Exists())
})
t.Run("removes context_management", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
assert.True(t, gjson.GetBytes(result, "messages").Exists())
})
t.Run("injects max_tokens when missing", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
})
t.Run("preserves existing max_tokens", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","max_tokens":4096,"messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.Equal(t, int64(4096), gjson.GetBytes(result, "max_tokens").Int())
})
t.Run("injects anthropic_version when missing", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
})
t.Run("preserves existing anthropic_version", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","anthropic_version":"bedrock-2023-05-31","messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
})
t.Run("no-op when fields already clean", func(t *testing.T) {
body := []byte(`{"model":"claude-opus-4-6","max_tokens":81920,"anthropic_version":"bedrock-2023-05-31","messages":[]}`)
result := sanitizeBedrockCCFields(body)
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
})
t.Run("full CC request sanitization", func(t *testing.T) {
body := []byte(`{
"model":"claude-opus-4-6",
"service_tier":"standard",
"interface_geo":"us",
"context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},
"messages":[{"role":"user","content":"hello"}],
"thinking":{"type":"enabled"}
}`)
result := sanitizeBedrockCCFields(body)
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
})
}
func TestSanitizeBedrockCCBetaTokens(t *testing.T) {
t.Run("filters unsupported beta tokens", func(t *testing.T) {
input := `{"anthropic_beta":["prompt-caching-2024-07-31","context-1m-2025-08-07","unsupported-feature"],"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
beta := gjson.GetBytes(result, "anthropic_beta")
assert.True(t, beta.Exists())
assert.True(t, beta.IsArray())
tokens := beta.Array()
assert.Equal(t, 1, len(tokens))
assert.Equal(t, "context-1m-2025-08-07", tokens[0].String())
})
t.Run("removes anthropic_beta if all tokens filtered", func(t *testing.T) {
input := `{"anthropic_beta":["prompt-caching-2024-07-31","unsupported-feature"],"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("thinking alone does not auto-inject beta tokens", func(t *testing.T) {
input := `{"anthropic_beta":[],"thinking":{"type":"enabled"},"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("auto-injects computer-use beta token", func(t *testing.T) {
input := `{"anthropic_beta":[],"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
beta := gjson.GetBytes(result, "anthropic_beta")
assert.True(t, beta.Exists())
tokens := beta.Array()
assert.Equal(t, 1, len(tokens))
assert.Equal(t, "computer-use-2025-11-24", tokens[0].String())
})
t.Run("transforms advanced-tool-use to tool-search-tool", func(t *testing.T) {
input := `{"anthropic_beta":["advanced-tool-use-2025-11-20"],"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
beta := gjson.GetBytes(result, "anthropic_beta")
tokens := beta.Array()
assert.Equal(t, 2, len(tokens)) // tool-search-tool + tool-examples (auto-associated)
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-search-tool-2025-10-19")
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-examples-2025-10-29")
})
t.Run("no-op when anthropic_beta not present", func(t *testing.T) {
input := `{"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("preserves supported beta tokens", func(t *testing.T) {
input := `{"anthropic_beta":["computer-use-2025-11-24","context-1m-2025-08-07"],"messages":[]}`
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
beta := gjson.GetBytes(result, "anthropic_beta")
tokens := beta.Array()
assert.Equal(t, 2, len(tokens))
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "computer-use-2025-11-24")
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "context-1m-2025-08-07")
})
}

View File

@ -248,6 +248,17 @@ func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
return ok && enabled
}
// IsBedrockCCCompatEnabled 返回该渠道是否启用了 Bedrock CC 兼容模式。
// 一旦启用,该渠道下所有请求都会应用 CC 兼容转换,不区分账号 platform。
func (c *Channel) IsBedrockCCCompatEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
// 直接检查 bedrock_cc_compat 开关,不再检查 platform 子字段
enabled, ok := c.FeaturesConfig[featureKeyBedrockCCCompat].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))

View File

@ -0,0 +1,73 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestChannel_IsBedrockCCCompatEnabled_Enabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyBedrockCCCompat: true,
},
}
require.True(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_AppliesToAllPlatforms(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyBedrockCCCompat: true,
},
}
require.True(t, c.IsBedrockCCCompatEnabled("anthropic"))
require.True(t, c.IsBedrockCCCompatEnabled("openai"))
require.True(t, c.IsBedrockCCCompatEnabled(""))
}
func TestChannel_IsBedrockCCCompatEnabled_Disabled(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyBedrockCCCompat: false,
},
}
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_NilFeaturesConfig(t *testing.T) {
c := &Channel{FeaturesConfig: nil}
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_NilChannel(t *testing.T) {
var c *Channel
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_WrongType(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyBedrockCCCompat: "yes",
},
}
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_OldMapFormat(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
featureKeyBedrockCCCompat: map[string]any{"bedrock": true},
},
}
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}
func TestChannel_IsBedrockCCCompatEnabled_MissingKey(t *testing.T) {
c := &Channel{
FeaturesConfig: map[string]any{
"other_feature": true,
},
}
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
}

View File

@ -281,9 +281,54 @@ func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt
if err != nil {
return "", "", status, err
}
if provider == MonitorProviderOpenAI && apiMode == MonitorAPIModeResponses {
return extractOpenAIResponsesText(respBytes), string(respBytes), status, nil
}
return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
}
// extractOpenAIResponsesText 聚合 Responses API 的最终 assistant 文本。
// Responses 的 output 数组顺序由模型决定reasoning / tool-call item 可能排在 message 前面,
// 因此不能假设文本永远在 output.0.content.0.text。
func extractOpenAIResponsesText(respBytes []byte) string {
if text := gjson.GetBytes(respBytes, "output_text").String(); strings.TrimSpace(text) != "" {
return text
}
var texts []string
outputs := gjson.GetBytes(respBytes, "output")
if outputs.IsArray() {
outputs.ForEach(func(_, output gjson.Result) bool {
outputType := output.Get("type").String()
if outputType != "" && outputType != "message" {
return true
}
content := output.Get("content")
if !content.IsArray() {
return true
}
content.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType != "" && blockType != "output_text" {
return true
}
if text := block.Get("text").String(); strings.TrimSpace(text) != "" {
texts = append(texts, text)
}
return true
})
return true
})
}
if len(texts) > 0 {
return strings.Join(texts, "")
}
return gjson.GetBytes(respBytes, providerOpenAIResponsesAdapter.textPath).String()
}
// mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。
// 用户值覆盖默认命中黑名单hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。
func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string {

View File

@ -60,10 +60,11 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
}
type openAICaptureHandler struct {
lastBody map[string]any
lastHeaders http.Header
lastPath string
status int
lastBody map[string]any
lastHeaders http.Header
lastPath string
status int
responsesLeadingReasoning bool
}
func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -82,10 +83,23 @@ func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
answer := answerFromOpenAIRequest(parsed)
if h.lastPath == providerOpenAIResponsesPath {
output := []map[string]any{}
if h.responsesLeadingReasoning {
output = append(output, map[string]any{
"type": "reasoning",
"summary": []any{},
})
}
output = append(output, map[string]any{
"type": "message",
"status": "completed",
"role": "assistant",
"content": []map[string]any{
{"type": "output_text", "text": answer},
},
})
_ = json.NewEncoder(w).Encode(map[string]any{
"output": []map[string]any{{
"content": []map[string]any{{"type": "output_text", "text": answer}},
}},
"output": output,
})
return
}
@ -212,6 +226,22 @@ func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) {
}
}
func TestRunCheckForModel_OpenAIResponses_SkipsLeadingReasoningItem(t *testing.T) {
h := &openAICaptureHandler{responsesLeadingReasoning: true}
endpoint := setupFakeOpenAI(t, h)
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-5.5", &CheckOptions{
APIMode: MonitorAPIModeResponses,
})
if res.Status != MonitorStatusOperational {
t.Fatalf("responses request should find text after leading reasoning item, got status=%s message=%q", res.Status, res.Message)
}
if h.lastPath != providerOpenAIResponsesPath {
t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath)
}
}
func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) {
h := &openAICaptureHandler{}
endpoint := setupFakeOpenAI(t, h)

View File

@ -3,13 +3,17 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// ConcurrencyCache 定义并发控制的缓存接口
@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error
}
const (
// Default extra wait slots beyond concurrency limit
// 默认等待队列额外槽位
defaultExtraWaitSlots = 20
defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond
accountLoadBatchFetchTimeout = 3 * time.Second
maxAccountLoadBatchCacheEntries = 256
)
// ConcurrencyService manages concurrent request limiting for accounts and users
// ConcurrencyService 管理账号和用户的并发限制。
type ConcurrencyService struct {
cache ConcurrencyCache
accountLoadCacheTTL atomic.Int64
accountLoadCacheMu sync.RWMutex
accountLoadCache map[string]cachedAccountLoadBatch
accountLoadGroup singleflight.Group
}
// NewConcurrencyService creates a new ConcurrencyService
type cachedAccountLoadBatch struct {
loadMap map[int64]*AccountLoadInfo
expiresAt time.Time
}
// NewConcurrencyService 创建并发控制服务。
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache}
svc := &ConcurrencyService{
cache: cache,
accountLoadCache: make(map[string]cachedAccountLoadBatch),
}
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
return svc
}
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
if s == nil {
return
}
s.accountLoadCacheTTL.Store(int64(ttl))
if ttl <= 0 {
s.accountLoadCacheMu.Lock()
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
s.accountLoadCacheMu.Unlock()
}
}
// AcquireResult represents the result of acquiring a concurrency slot
@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
// GetAccountsLoadBatch 批量获取账号负载信息。
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
return s.getAccountsLoadBatch(ctx, accounts, true)
}
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
return s.getAccountsLoadBatch(ctx, accounts, false)
}
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
if len(accounts) == 0 {
return map[int64]*AccountLoadInfo{}, nil
}
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
ttl := time.Duration(s.accountLoadCacheTTL.Load())
if !allowCache || ttl <= 0 {
return s.fetchAccountsLoadBatch(ctx, accounts)
}
key := accountLoadBatchCacheKey(accounts)
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
return cached, nil
}
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
now := time.Now()
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
return cached, nil
}
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
if fetchErr != nil {
return nil, fetchErr
}
cached := cloneAccountLoadMap(loadMap)
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
return cached, nil
})
if err != nil {
return nil, err
}
loadMap, _ := value.(map[int64]*AccountLoadInfo)
if loadMap == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return loadMap, nil
}
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
baseCtx := context.Background()
if ctx != nil {
baseCtx = context.WithoutCancel(ctx)
}
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
defer cancel()
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
}
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
s.accountLoadCacheMu.RLock()
cached, ok := s.accountLoadCache[key]
s.accountLoadCacheMu.RUnlock()
if !ok {
return nil, false
}
if !now.Before(cached.expiresAt) {
s.accountLoadCacheMu.Lock()
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
delete(s.accountLoadCache, key)
}
s.accountLoadCacheMu.Unlock()
return nil, false
}
return cached.loadMap, true
}
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
s.accountLoadCacheMu.Lock()
if s.accountLoadCache == nil {
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
}
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
now := time.Now()
for cacheKey, cached := range s.accountLoadCache {
if !now.Before(cached.expiresAt) {
delete(s.accountLoadCache, cacheKey)
}
}
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
for cacheKey := range s.accountLoadCache {
delete(s.accountLoadCache, cacheKey)
break
}
}
}
s.accountLoadCache[key] = cachedAccountLoadBatch{
loadMap: loadMap,
expiresAt: expiresAt,
}
s.accountLoadCacheMu.Unlock()
}
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
hash := sha256.New()
var buf [16]byte
for _, account := range accounts {
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
_, _ = hash.Write(buf[:])
}
sum := hash.Sum(nil)
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
}
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
if len(loadMap) == 0 {
return map[int64]*AccountLoadInfo{}
}
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
for accountID, loadInfo := range loadMap {
if loadInfo == nil {
clone[accountID] = nil
continue
}
copied := *loadInfo
clone[accountID] = &copied
}
return clone
}
// GetUsersLoadBatch returns load info for multiple users.

View File

@ -7,7 +7,9 @@ import (
"errors"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct {
// 记录调用
releasedAccountIDs []int64
releasedRequestIDs []string
loadBatchCalls atomic.Int64
}
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in
return nil
}
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
c.loadBatchCalls.Add(1)
return c.loadBatch, c.loadBatchErr
}
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
require.Empty(t, result)
}
func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) {
cache := &stubConcurrencyCacheForTest{
loadBatch: map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
},
}
svc := NewConcurrencyService(cache)
svc.SetAccountLoadBatchCacheTTL(time.Second)
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
first, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 1, first[int64(1)].CurrentConcurrency)
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
second, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 1, second[int64(1)].CurrentConcurrency)
require.Equal(t, int64(1), cache.loadBatchCalls.Load())
}
func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) {
cache := &stubConcurrencyCacheForTest{
loadBatch: map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
},
}
svc := NewConcurrencyService(cache)
svc.SetAccountLoadBatchCacheTTL(time.Second)
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
_, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency)
require.Equal(t, int64(2), cache.loadBatchCalls.Load())
}
func TestIncrementWaitCount_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
svc := NewConcurrencyService(cache)

View File

@ -44,6 +44,10 @@ const (
ContentModerationKeywordModeKeywordAndAPI = "keyword_and_api"
ContentModerationKeywordModeAPIOnly = "api_only"
ContentModerationModelFilterAll = "all"
ContentModerationModelFilterInclude = "include"
ContentModerationModelFilterExclude = "exclude"
ContentModerationProtocolAnthropicMessages = "anthropic_messages"
ContentModerationProtocolOpenAIResponses = "openai_responses"
ContentModerationProtocolOpenAIChat = "openai_chat_completions"
@ -80,6 +84,8 @@ const (
maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024
maxContentModerationBlockedKeywords = 10000
maxContentModerationBlockedKeywordRunes = 200
maxContentModerationModelFilterModels = 1000
maxContentModerationModelFilterRunes = 200
contentModerationCleanupInterval = 24 * time.Hour
contentModerationCleanupTimeout = 30 * time.Minute
@ -127,32 +133,33 @@ func ContentModerationCategories() []string {
}
type ContentModerationConfig struct {
Enabled bool `json:"enabled"`
Mode string `json:"mode"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
APIKey string `json:"api_key,omitempty"`
APIKeys []string `json:"api_keys,omitempty"`
TimeoutMS int `json:"timeout_ms"`
SampleRate int `json:"sample_rate"`
AllGroups bool `json:"all_groups"`
GroupIDs []int64 `json:"group_ids"`
RecordNonHits bool `json:"record_non_hits"`
Thresholds map[string]float64 `json:"thresholds"`
WorkerCount int `json:"worker_count"`
QueueSize int `json:"queue_size"`
BlockStatus int `json:"block_status"`
BlockMessage string `json:"block_message"`
EmailOnHit bool `json:"email_on_hit"`
AutoBanEnabled bool `json:"auto_ban_enabled"`
BanThreshold int `json:"ban_threshold"`
ViolationWindowHours int `json:"violation_window_hours"`
RetryCount int `json:"retry_count"`
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
BlockedKeywords []string `json:"blocked_keywords"`
KeywordBlockingMode string `json:"keyword_blocking_mode"`
Enabled bool `json:"enabled"`
Mode string `json:"mode"`
BaseURL string `json:"base_url"`
Model string `json:"model"`
APIKey string `json:"api_key,omitempty"`
APIKeys []string `json:"api_keys,omitempty"`
TimeoutMS int `json:"timeout_ms"`
SampleRate int `json:"sample_rate"`
AllGroups bool `json:"all_groups"`
GroupIDs []int64 `json:"group_ids"`
RecordNonHits bool `json:"record_non_hits"`
Thresholds map[string]float64 `json:"thresholds"`
WorkerCount int `json:"worker_count"`
QueueSize int `json:"queue_size"`
BlockStatus int `json:"block_status"`
BlockMessage string `json:"block_message"`
EmailOnHit bool `json:"email_on_hit"`
AutoBanEnabled bool `json:"auto_ban_enabled"`
BanThreshold int `json:"ban_threshold"`
ViolationWindowHours int `json:"violation_window_hours"`
RetryCount int `json:"retry_count"`
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
BlockedKeywords []string `json:"blocked_keywords"`
KeywordBlockingMode string `json:"keyword_blocking_mode"`
ModelFilter ContentModerationModelFilter `json:"model_filter"`
}
type ContentModerationConfigView struct {
@ -184,6 +191,7 @@ type ContentModerationConfigView struct {
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
BlockedKeywords []string `json:"blocked_keywords"`
KeywordBlockingMode string `json:"keyword_blocking_mode"`
ModelFilter ContentModerationModelFilter `json:"model_filter"`
}
type ContentModerationAPIKeyStatus struct {
@ -227,34 +235,40 @@ type ContentModerationTestAuditResult struct {
}
type UpdateContentModerationConfigInput struct {
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
APIKeysMode string `json:"api_keys_mode"`
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
BlockedKeywords *[]string `json:"blocked_keywords"`
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
Enabled *bool `json:"enabled"`
Mode *string `json:"mode"`
BaseURL *string `json:"base_url"`
Model *string `json:"model"`
APIKey *string `json:"api_key"`
APIKeys *[]string `json:"api_keys"`
APIKeysMode string `json:"api_keys_mode"`
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
ClearAPIKey bool `json:"clear_api_key"`
TimeoutMS *int `json:"timeout_ms"`
SampleRate *int `json:"sample_rate"`
AllGroups *bool `json:"all_groups"`
GroupIDs *[]int64 `json:"group_ids"`
RecordNonHits *bool `json:"record_non_hits"`
WorkerCount *int `json:"worker_count"`
QueueSize *int `json:"queue_size"`
BlockStatus *int `json:"block_status"`
BlockMessage *string `json:"block_message"`
EmailOnHit *bool `json:"email_on_hit"`
AutoBanEnabled *bool `json:"auto_ban_enabled"`
BanThreshold *int `json:"ban_threshold"`
ViolationWindowHours *int `json:"violation_window_hours"`
RetryCount *int `json:"retry_count"`
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
BlockedKeywords *[]string `json:"blocked_keywords"`
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
ModelFilter *ContentModerationModelFilter `json:"model_filter"`
}
type ContentModerationModelFilter struct {
Type string `json:"type"`
Models []string `json:"models"`
}
type ContentModerationCheckInput struct {
@ -581,6 +595,9 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat
if input.KeywordBlockingMode != nil {
cfg.KeywordBlockingMode = strings.TrimSpace(*input.KeywordBlockingMode)
}
if input.ModelFilter != nil {
cfg.ModelFilter = *input.ModelFilter
}
if input.AllGroups != nil {
cfg.AllGroups = *input.AllGroups
}
@ -719,7 +736,8 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
"error", err)
return allow, nil
}
inScope := cfg.includesGroup(input.GroupID)
inGroupScope := cfg.includesGroup(input.GroupID)
inModelScope := cfg.includesModel(input.Model)
slog.Info("content_moderation.config_loaded",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
@ -733,7 +751,10 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
"mode", cfg.Mode,
"all_groups", cfg.AllGroups,
"configured_group_ids", cfg.GroupIDs,
"in_scope", inScope,
"in_group_scope", inGroupScope,
"model_filter_type", cfg.ModelFilter.Type,
"configured_models", cfg.ModelFilter.Models,
"in_model_scope", inModelScope,
"sample_rate", cfg.SampleRate,
"api_key_count", len(cfg.apiKeys()),
"pre_hash_check_enabled", cfg.PreHashCheckEnabled,
@ -756,7 +777,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
"protocol", input.Protocol)
return allow, nil
}
if !inScope {
if !inGroupScope {
slog.Info("content_moderation.skip_group_out_of_scope",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
@ -768,6 +789,19 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
"configured_group_ids", cfg.GroupIDs)
return allow, nil
}
if !inModelScope {
slog.Info("content_moderation.skip_model_out_of_scope",
"user_id", input.UserID,
"api_key_id", input.APIKeyID,
"group_id", contentModerationLogGroupID(input.GroupID),
"group_name", input.GroupName,
"endpoint", input.Endpoint,
"protocol", input.Protocol,
"model", input.Model,
"model_filter_type", cfg.ModelFilter.Type,
"configured_models", cfg.ModelFilter.Models)
return allow, nil
}
content := ExtractContentModerationInput(input.Protocol, input.Body)
if content.IsEmpty() {
slog.Info("content_moderation.skip_empty_input",
@ -1025,6 +1059,9 @@ func (s *ContentModerationService) worker(id int) {
if !cfg.includesGroup(task.input.GroupID) {
return
}
if !cfg.includesModel(task.input.Model) {
return
}
s.asyncActive.Add(1)
defer s.asyncActive.Add(-1)
queueDelay := int(time.Since(task.enqueuedAt).Milliseconds())
@ -1270,6 +1307,9 @@ func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *Cont
if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 {
return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间")
}
if cfg.ModelFilter.Type != ContentModerationModelFilterAll && len(cfg.ModelFilter.Models) == 0 {
return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODEL_FILTER", "指定或排除模型时至少需要配置 1 个模型")
}
if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil {
for _, groupID := range cfg.GroupIDs {
if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil {
@ -1590,6 +1630,10 @@ func defaultContentModerationConfig() *ContentModerationConfig {
PreHashCheckEnabled: false,
BlockedKeywords: []string{},
KeywordBlockingMode: ContentModerationKeywordModeKeywordAndAPI,
ModelFilter: ContentModerationModelFilter{
Type: ContentModerationModelFilterAll,
Models: []string{},
},
}
}
@ -1670,6 +1714,7 @@ func (cfg *ContentModerationConfig) normalize() {
cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds)
cfg.BlockedKeywords = normalizeBlockedKeywords(cfg.BlockedKeywords)
cfg.KeywordBlockingMode = normalizeKeywordBlockingMode(cfg.KeywordBlockingMode)
cfg.ModelFilter = normalizeContentModerationModelFilter(cfg.ModelFilter)
}
func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
@ -1687,6 +1732,21 @@ func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
return false
}
func (cfg *ContentModerationConfig) includesModel(model string) bool {
if cfg == nil {
return true
}
filter := normalizeContentModerationModelFilter(cfg.ModelFilter)
switch filter.Type {
case ContentModerationModelFilterInclude:
return contentModerationModelListContains(filter.Models, model)
case ContentModerationModelFilterExclude:
return !contentModerationModelListContains(filter.Models, model)
default:
return true
}
}
func contentModerationLogGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
@ -1848,6 +1908,7 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con
PreHashCheckEnabled: cfg.PreHashCheckEnabled,
BlockedKeywords: append([]string(nil), cfg.BlockedKeywords...),
KeywordBlockingMode: cfg.KeywordBlockingMode,
ModelFilter: cloneContentModerationModelFilter(cfg.ModelFilter),
}
}
@ -2125,6 +2186,73 @@ func normalizeKeywordBlockingMode(mode string) string {
}
}
func normalizeContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter {
out := ContentModerationModelFilter{
Type: normalizeContentModerationModelFilterType(filter.Type),
Models: normalizeContentModerationModelNames(filter.Models),
}
if out.Type == ContentModerationModelFilterAll {
out.Models = []string{}
}
return out
}
func cloneContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter {
normalized := normalizeContentModerationModelFilter(filter)
normalized.Models = append([]string(nil), normalized.Models...)
return normalized
}
func normalizeContentModerationModelFilterType(filterType string) string {
switch strings.ToLower(strings.TrimSpace(filterType)) {
case ContentModerationModelFilterInclude:
return ContentModerationModelFilterInclude
case ContentModerationModelFilterExclude:
return ContentModerationModelFilterExclude
case ContentModerationModelFilterAll:
return ContentModerationModelFilterAll
default:
return ContentModerationModelFilterAll
}
}
func normalizeContentModerationModelNames(models []string) []string {
if len(models) == 0 {
return []string{}
}
out := make([]string, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, raw := range models {
model := trimRunes(strings.TrimSpace(raw), maxContentModerationModelFilterRunes)
if model == "" {
continue
}
key := strings.ToLower(model)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, model)
if len(out) >= maxContentModerationModelFilterModels {
break
}
}
return out
}
func contentModerationModelListContains(models []string, model string) bool {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return false
}
for _, candidate := range models {
if strings.ToLower(strings.TrimSpace(candidate)) == model {
return true
}
}
return false
}
func matchBlockedKeyword(text string, keywords []string) (string, bool) {
if text == "" || len(keywords) == 0 {
return "", false

View File

@ -48,44 +48,44 @@ func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string,
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
var candidate []string
var candidateImages []string
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
array := messages.Array()
if len(array) == 0 {
return
}
last := array[len(array)-1]
if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != role {
return
}
var candidate []string
var candidateImages []string
collectContentValue(last.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
return
}
*parts = append(*parts, candidate...)
*images = append(*images, candidateImages...)
}
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
if !messages.IsArray() {
return
}
var lastParts []string
var lastImages []string
messages.ForEach(func(_, msg gjson.Result) bool {
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
var candidate []string
var candidateImages []string
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
array := messages.Array()
if len(array) == 0 {
return
}
last := array[len(array)-1]
if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != "user" {
return
}
var candidate []string
var candidateImages []string
collectAnthropicUserContentValue(last.Get("content"), &candidate, &candidateImages)
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
return
}
*parts = append(*parts, candidate...)
*images = append(*images, candidateImages...)
}
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
@ -128,18 +128,17 @@ func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]st
case input.Type == gjson.String:
addModerationText(parts, input.String())
case input.IsArray():
var last gjson.Result
input.ForEach(func(_, item gjson.Result) bool {
if isResponsesUserTextItem(item) {
last = item
}
return true
})
if last.Exists() {
collectContentValue(last.Get("content"), parts, images)
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
collectContentValue(last, parts, images)
}
array := input.Array()
if len(array) == 0 {
return
}
last := array[len(array)-1]
if !isResponsesUserTextItem(last) {
return
}
collectContentValue(last.Get("content"), parts, images)
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
collectContentValue(last, parts, images)
}
case input.IsObject():
if isResponsesUserTextItem(input) {
@ -176,29 +175,29 @@ func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]
if !contents.IsArray() {
return
}
var lastParts []string
var lastImages []string
contents.ForEach(func(_, content gjson.Result) bool {
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
if role == "" || role == "user" {
var candidate []string
var candidateImages []string
if arr := content.Get("parts"); arr.IsArray() {
arr.ForEach(func(_, part gjson.Result) bool {
addModerationText(&candidate, part.Get("text").String())
addGeminiModerationImage(&candidateImages, part)
return true
})
}
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
lastParts = candidate
lastImages = candidateImages
}
}
return true
})
*parts = append(*parts, lastParts...)
*images = append(*images, lastImages...)
array := contents.Array()
if len(array) == 0 {
return
}
last := array[len(array)-1]
role := strings.ToLower(strings.TrimSpace(last.Get("role").String()))
if role != "" && role != "user" {
return
}
var candidate []string
var candidateImages []string
if arr := last.Get("parts"); arr.IsArray() {
arr.ForEach(func(_, part gjson.Result) bool {
addModerationText(&candidate, part.Get("text").String())
addGeminiModerationImage(&candidateImages, part)
return true
})
}
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
return
}
*parts = append(*parts, candidate...)
*images = append(*images, candidateImages...)
}
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {

View File

@ -0,0 +1,179 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// 当数组末尾不是用户消息时典型场景Agent 工具循环结束于 tool/assistant
// 应直接跳过审计——不再回溯查找历史中的某条用户消息。
func TestExtractContentModerationInput_AnthropicAgentToolLoopSkipsAudit(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"调用一下天气工具"},
{"role":"assistant","content":[{"type":"tool_use","id":"tool_1","name":"weather","input":{}}]},
{"role":"user","content":[{"type":"tool_result","tool_use_id":"tool_1","content":"晴 25 度"}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Empty(t, input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_AnthropicFirstTurnExtractsUser(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"Q1"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "Q1", input.Text)
}
func TestExtractContentModerationInput_AnthropicMultiTurnExtractsLatestUser(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"Q1"},
{"role":"assistant","content":"A1"},
{"role":"user","content":"Q2"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "Q2", input.Text)
}
func TestExtractContentModerationInput_AnthropicStreamResendExtractsResend(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"原问题"},
{"role":"assistant","content":"部分回答……"},
{"role":"user","content":"重发"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
require.Equal(t, "重发", input.Text)
}
func TestExtractContentModerationInput_OpenAIChatAgentToolLoopSkipsAudit(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"system","content":"sys"},
{"role":"user","content":"列出我的订单"},
{"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"orders","arguments":"{}"}}]},
{"role":"tool","tool_call_id":"call_1","content":"[]"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
require.Empty(t, input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_OpenAIChatMultiTurnExtractsLatestUser(t *testing.T) {
body := []byte(`{
"messages": [
{"role":"user","content":"Q1"},
{"role":"assistant","content":"A1"},
{"role":"user","content":"Q2"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
require.Equal(t, "Q2", input.Text)
}
func TestExtractContentModerationInput_GeminiAgentToolLoopSkipsAudit(t *testing.T) {
body := []byte(`{
"contents": [
{"role":"user","parts":[{"text":"查询天气"}]},
{"role":"model","parts":[{"functionCall":{"name":"weather","args":{}}}]},
{"role":"user","parts":[{"functionResponse":{"name":"weather","response":{"temp":25}}}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
require.Empty(t, input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_GeminiFirstTurnExtractsUser(t *testing.T) {
body := []byte(`{
"contents": [
{"role":"user","parts":[{"text":"你好"}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
require.Equal(t, "你好", input.Text)
}
func TestExtractContentModerationInput_GeminiMultiTurnExtractsLatestUser(t *testing.T) {
body := []byte(`{
"contents": [
{"role":"user","parts":[{"text":"Q1"}]},
{"role":"model","parts":[{"text":"A1"}]},
{"role":"user","parts":[{"text":"Q2"}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
require.Equal(t, "Q2", input.Text)
}
func TestExtractContentModerationInput_ResponsesAgentToolLoopSkipsAudit(t *testing.T) {
body := []byte(`{
"input":[
{"type":"message","role":"user","content":[{"type":"input_text","text":"运行测试"}]},
{"type":"function_call","call_id":"call_1","name":"run_tests","arguments":"{}"},
{"type":"function_call_output","call_id":"call_1","output":"all passed"}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
require.Empty(t, input.Text)
require.Empty(t, input.Images)
}
func TestExtractContentModerationInput_ResponsesLastUserMessageExtracted(t *testing.T) {
body := []byte(`{
"input":[
{"type":"message","role":"user","content":[{"type":"input_text","text":"first"}]},
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"answer"}]},
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest"}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
require.Equal(t, "latest", input.Text)
}
func TestExtractContentModerationInput_ResponsesLastIsAssistantSkipped(t *testing.T) {
body := []byte(`{
"input":[
{"type":"message","role":"user","content":[{"type":"input_text","text":"q1"}]},
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"a1"}]}
]
}`)
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
require.Empty(t, input.Text)
require.Empty(t, input.Images)
}

View File

@ -530,6 +530,147 @@ func TestNormalizeKeywordBlockingMode_UnknownFallsBackToDefault(t *testing.T) {
require.Equal(t, ContentModerationKeywordModeAPIOnly, normalizeKeywordBlockingMode("api_only"))
}
func TestContentModerationCheck_ModelFilterAllAuditsEveryModel(t *testing.T) {
cfg := defaultContentModerationModelFilterTestConfig()
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterAll}
svc, repo := newContentModerationModelFilterTestService(t, cfg)
for _, model := range []string{"gpt-5.5", "gpt-5.4"} {
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Model: model,
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
}
require.Len(t, repo.logs, 2)
}
func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *testing.T) {
cfg := defaultContentModerationModelFilterTestConfig()
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}}
svc, repo := newContentModerationModelFilterTestService(t, cfg)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
Model: "gpt-5.4",
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Allowed)
require.False(t, decision.Blocked)
require.Equal(t, ContentModerationActionAllow, decision.Action)
require.Len(t, repo.logs, 1)
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
}
func TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T) {
cfg := defaultContentModerationModelFilterTestConfig()
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterExclude, Models: []string{"gpt-5.4"}}
svc, repo := newContentModerationModelFilterTestService(t, cfg)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
Model: "gpt-5.4",
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Allowed)
require.False(t, decision.Blocked)
require.Equal(t, ContentModerationActionAllow, decision.Action)
require.Len(t, repo.logs, 1)
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
}
func TestContentModerationLoadConfig_LegacyConfigDefaultsModelFilterToAll(t *testing.T) {
raw := `{"enabled":true,"mode":"pre_block","base_url":"https://api.openai.com","model":"omni-moderation-latest","blocked_keywords":["secret-token"]}`
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyContentModerationConfig: raw,
}},
nil,
nil,
nil,
nil,
nil,
nil,
)
cfg, err := svc.loadConfig(context.Background())
require.NoError(t, err)
require.Equal(t, ContentModerationModelFilterAll, cfg.ModelFilter.Type)
require.Empty(t, cfg.ModelFilter.Models)
require.True(t, cfg.includesModel("gpt-5.5"))
require.True(t, cfg.includesModel("gpt-5.4"))
}
func TestContentModerationCheck_ModelFilterUsesRequestedModelNotBodyModel(t *testing.T) {
cfg := defaultContentModerationModelFilterTestConfig()
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}}
svc, repo := newContentModerationModelFilterTestService(t, cfg)
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
Model: "gpt-5.5",
Protocol: ContentModerationProtocolOpenAIChat,
Body: []byte(`{"model":"mapped-upstream-model","messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
})
require.NoError(t, err)
require.True(t, decision.Blocked)
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
require.Len(t, repo.logs, 1)
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
}
func defaultContentModerationModelFilterTestConfig() *ContentModerationConfig {
cfg := defaultContentModerationConfig()
cfg.Enabled = true
cfg.Mode = ContentModerationModePreBlock
cfg.BlockedKeywords = []string{"secret-token"}
return cfg
}
func newContentModerationModelFilterTestService(t *testing.T, cfg *ContentModerationConfig) (*ContentModerationService, *contentModerationTestRepo) {
t.Helper()
rawCfg, err := json.Marshal(cfg)
require.NoError(t, err)
repo := &contentModerationTestRepo{}
svc := NewContentModerationService(
&contentModerationTestSettingRepo{values: map[string]string{
SettingKeyRiskControlEnabled: "true",
SettingKeyContentModerationConfig: string(rawCfg),
}},
repo,
&contentModerationTestHashCache{},
nil,
nil,
nil,
nil,
)
return svc, repo
}
func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.APIKeys = []string{"sk-old-a", "sk-old-b"}

View File

@ -132,6 +132,9 @@ const (
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
// API Key IP 访问控制设置
SettingKeyAPIKeyACLTrustForwardedIP = "api_key_acl_trust_forwarded_ip" // API Key IP 白/黑名单是否信任转发 IP
// TOTP 双因素认证设置
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
@ -406,12 +409,15 @@ const (
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
// Balance Low Notification
// 余额不足提醒
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值USD
SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL
// Account Quota Notification
// 订阅到期提醒
SettingKeySubscriptionExpiryNotifyEnabled = "subscription_expiry_notify_enabled" // 订阅到期提醒全局开关,默认开启
// 账号限额通知
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组)

View File

@ -5739,6 +5739,31 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header,
}
}
// ApplyBedrockCCCompat 应用 Bedrock CC 兼容转换(渠道级模型映射后调用)
// 清理 Anthropic API 专有字段、注入 Bedrock 必需字段、修复 thinking/tool_use ID
func (s *GatewayService) ApplyBedrockCCCompat(ctx context.Context, body []byte, model string, account *Account, groupID *int64) []byte {
if !s.isBedrockCCCompatEnabled(ctx, account, groupID) {
return body
}
body = sanitizeBedrockCCFields(body)
body = sanitizeBedrockThinking(body, model)
body = sanitizeBedrockToolUseIDs(body)
body = sanitizeBedrockCCBetaTokens(body, model)
return body
}
// isBedrockCCCompatEnabled 检查渠道是否启用了 Bedrock CC 兼容模式
func (s *GatewayService) isBedrockCCCompatEnabled(ctx context.Context, account *Account, groupID *int64) bool {
if groupID == nil || s.channelService == nil {
return false
}
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
if err != nil || ch == nil {
return false
}
return ch.IsBedrockCCCompatEnabled(account.Platform)
}
// forwardBedrock 转发请求到 AWS Bedrock
func (s *GatewayService) forwardBedrock(
ctx context.Context,
@ -5771,7 +5796,7 @@ func (s *GatewayService) forwardBedrock(
return nil, err
}
bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens)
bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens, false)
if err != nil {
return nil, fmt.Errorf("prepare bedrock request body: %w", err)
}

View File

@ -126,17 +126,17 @@ func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *test
}
}
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse 验证:
// 管理员 block 了 computer-use,客户端不在 header 中带该 token
// 但请求体包含 computer_use 工具 → 自动注入后应被 block。
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse(t *testing.T) {
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "interleaved-thinking-2025-05-14",
BetaToken: "computer-use-2025-11-24",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "thinking is blocked",
ErrorMessage: "computer use is blocked",
},
},
}
@ -155,18 +155,18 @@ func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *te
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// header 中不带 beta token但 body 中有 thinking 字段
// header 中不带 beta token但 body 中有 computer_use 工具
_, err = svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"", // 空 header
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
[]byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err == nil {
t.Fatal("expected body-injected interleaved-thinking to be blocked")
t.Fatal("expected body-injected computer-use to be blocked")
}
if err.Error() != "thinking is blocked" {
if err.Error() != "computer use is blocked" {
t.Fatalf("unexpected error: %v", err)
}
}
@ -222,10 +222,10 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
settings := &BetaPolicySettings{
Rules: []BetaPolicyRule{
{
BetaToken: "computer-use-2025-11-24",
BetaToken: "context-1m-2025-08-07",
Action: BetaPolicyActionBlock,
Scope: BetaPolicyScopeAll,
ErrorMessage: "computer use is blocked",
ErrorMessage: "context is blocked",
},
},
}
@ -244,12 +244,12 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
}
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
// body 中有 thinking会注入 interleaved-thinking但 block 规则只针对 computer-use
// body 中有 computer_use 工具(会注入 computer-use token但 block 规则只针对 context-1m
tokens, err := svc.resolveBedrockBetaTokensForRequest(
context.Background(),
account,
"",
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
[]byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`),
"us.anthropic.claude-opus-4-6-v1",
)
if err != nil {
@ -257,11 +257,11 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
}
found := false
for _, token := range tokens {
if token == "interleaved-thinking-2025-05-14" {
if token == "computer-use-2025-11-24" {
found = true
}
}
if !found {
t.Fatal("expected interleaved-thinking token to be present")
t.Fatal("expected computer-use token to be present")
}
}

View File

@ -0,0 +1,169 @@
package service
import (
"context"
"net/http"
"time"
)
const (
openAIAccountStateUpdateTimeout = 5 * time.Second
openAIOAuth429FallbackCooldown = 5 * time.Second
openAIStopSchedulingBridgeCooldown = 2 * time.Minute
openAIOAuth429StormWindow = 10 * time.Second
openAIOAuth429StormThreshold = 20
openAIOAuth429StormMaxAccountSwitches = 1
)
func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) {
base := context.Background()
if ctx != nil {
base = context.WithoutCancel(ctx)
}
return context.WithTimeout(base, openAIAccountStateUpdateTimeout)
}
func isOpenAIOAuthAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
}
func isOpenAIAccount(account *Account) bool {
return account != nil && account.Platform == PlatformOpenAI
}
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
stateCtx, cancel := openAIAccountStateContext(ctx)
defer cancel()
if statusCode == http.StatusTooManyRequests {
s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody)
}
if s == nil || account == nil || s.rateLimitService == nil {
return false
}
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
if shouldDisable {
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
}
return shouldDisable
}
func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
if s == nil || !isOpenAIOAuthAccount(account) {
return
}
s.recordOpenAIOAuth429()
cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown)
if s.rateLimitService != nil {
if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) {
cooldownUntil = *resetAt
} else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil {
if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) {
cooldownUntil = resetAt
}
} else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 {
cooldownUntil = time.Now().Add(cooldown)
}
}
s.BlockAccountScheduling(account, cooldownUntil, "429")
}
func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) {
if s == nil || !isOpenAIAccount(account) {
return
}
now := time.Now()
blockUntil := until
if blockUntil.IsZero() || !blockUntil.After(now) {
blockUntil = now.Add(openAIStopSchedulingBridgeCooldown)
}
for {
current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
if !loaded {
actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil)
if !stored {
return
}
current = actual
}
currentUntil, ok := current.(time.Time)
if !ok || currentUntil.IsZero() {
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
return
}
continue
}
if currentUntil.After(blockUntil) {
return
}
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
return
}
}
}
func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) {
if s == nil || accountID <= 0 {
return
}
s.openaiAccountRuntimeBlockUntil.Delete(accountID)
}
func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool {
if s == nil || !isOpenAIAccount(account) {
return false
}
value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
if !ok {
return false
}
cooldownUntil, ok := value.(time.Time)
if !ok || cooldownUntil.IsZero() {
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
return false
}
if time.Now().Before(cooldownUntil) {
return true
}
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
return false
}
func (s *OpenAIGatewayService) recordOpenAIOAuth429() {
if s == nil {
return
}
now := time.Now()
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) {
s.openaiOAuth429WindowCount.Store(1)
return
}
}
s.openaiOAuth429WindowCount.Add(1)
}
func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool {
if s == nil {
return false
}
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
return false
}
return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold
}
func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool {
if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches {
return false
}
if !isOpenAIOAuthAccount(account) {
return false
}
return s.isOpenAIOAuth429Storm()
}

View File

@ -0,0 +1,101 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil)
apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil)
require.False(t, shouldDisable)
require.False(t, apiKeyShouldDisable)
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount))
}
func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) {
gateway := &OpenAIGatewayService{}
repo := &rateLimitAccountRepoStub{}
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
rateLimitService.SetAccountRuntimeBlocker(gateway)
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden"))
require.True(t, shouldDisable)
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
}
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
longUntil := time.Now().Add(10 * time.Minute)
svc.BlockAccountScheduling(account, longUntil, "oauth_401")
svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID)
require.True(t, ok)
actualUntil, ok := value.(time.Time)
require.True(t, ok)
require.WithinDuration(t, longUntil, actualUntil, time.Second)
}
func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429")
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
svc.ClearAccountSchedulingBlock(account.ID)
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
}
func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) {
svc := &OpenAIGatewayService{}
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
for i := 0; i < openAIOAuth429StormThreshold; i++ {
svc.recordOpenAIOAuth429()
}
require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1))
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0))
}

View File

@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct {
loadSkewMilliTotal atomic.Int64
}
type openAIAccountLoadPlan struct {
allCandidates []openAIAccountCandidateScore
candidates []openAIAccountCandidateScore
staleSnapshotCompactRetry []openAIAccountCandidateScore
selectionOrder []openAIAccountCandidateScore
candidateCount int
topK int
loadSkew float64
}
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
if m == nil {
return
@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
if acquireErr == nil && result != nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
return &AccountSelectionResult{
Account: account,
@ -586,6 +596,234 @@ func buildOpenAIWeightedSelectionOrder(
return order
}
func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan(
req OpenAIAccountScheduleRequest,
filtered []*Account,
loadMap map[int64]*AccountLoadInfo,
) openAIAccountLoadPlan {
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
errorRate, ttft, hasTTFT := 0.0, 0.0, false
if s.stats != nil {
errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID)
}
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
}
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
}
plan := openAIAccountLoadPlan{
allCandidates: allCandidates,
candidates: candidates,
staleSnapshotCompactRetry: staleSnapshotCompactRetry,
candidateCount: len(candidates),
}
if len(candidates) == 0 {
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
return plan
}
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
quotaFactor := item.account.GetQuotaRemainingFraction()
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor +
weights.Quota*quotaFactor
}
plan.candidates = candidates
plan.topK = s.service.openAIWSLBTopK()
if plan.topK > len(candidates) {
plan.topK = len(candidates)
}
if plan.topK <= 0 {
plan.topK = 1
}
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
return plan
}
func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder(
req OpenAIAccountScheduleRequest,
plan openAIAccountLoadPlan,
) []openAIAccountCandidateScore {
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || plan.topK <= 0 {
return nil
}
groupTopK := plan.topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
for _, candidate := range plan.candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates))
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...)
}
return selectionOrder
}
return buildSelectionOrder(plan.candidates)
}
func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
ctx context.Context,
req OpenAIAccountScheduleRequest,
selectionOrder []openAIAccountCandidateScore,
) (*AccountSelectionResult, bool, error) {
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, compactBlocked, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, compactBlocked, nil
}
}
return nil, compactBlocked, nil
}
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
@ -616,8 +854,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if !account.IsSchedulable() || !account.IsOpenAI() {
continue
}
if s.service.isOpenAIAccountRuntimeBlocked(account) {
continue
}
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set")
_ = s.service.accountRepo.SetError(ctx, account.ID,
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
continue
@ -645,214 +887,47 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
for _, account := range filtered {
loadInfo := loadMap[account.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: account.ID}
}
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
allCandidates = append(allCandidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
errorRate: errorRate,
ttft: ttft,
hasTTFT: hasTTFT,
})
plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap)
candidateCount := plan.candidateCount
topK := plan.topK
loadSkew := plan.loadSkew
selectionOrder := plan.selectionOrder
if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
// 时作为最后兜底snapshot 可能已陈旧)。
candidates := allCandidates
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
for _, candidate := range allCandidates {
if openAICompactSupportTier(candidate.account) == 0 {
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
continue
}
candidates = append(candidates, candidate)
}
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
}
}
candidateCount := len(candidates)
loadSkew := 0.0
if len(candidates) > 0 {
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
maxWaiting := 1
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
hasTTFTSample := false
for _, candidate := range candidates {
if candidate.account.Priority < minPriority {
minPriority = candidate.account.Priority
}
if candidate.account.Priority > maxPriority {
maxPriority = candidate.account.Priority
}
if candidate.loadInfo.WaitingCount > maxWaiting {
maxWaiting = candidate.loadInfo.WaitingCount
}
if candidate.hasTTFT && candidate.ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
hasTTFTSample = true
} else {
if candidate.ttft < minTTFT {
minTTFT = candidate.ttft
}
if candidate.ttft > maxTTFT {
maxTTFT = candidate.ttft
}
}
}
loadRate := float64(candidate.loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
}
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
weights := s.service.openAIWSSchedulerWeights()
for i := range candidates {
item := &candidates[i]
priorityFactor := 1.0
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
}
quotaFactor := item.account.GetQuotaRemainingFraction()
item.score = weights.Priority*priorityFactor +
weights.Load*loadFactor +
weights.Queue*queueFactor +
weights.ErrorRate*errorFactor +
weights.TTFT*ttftFactor +
weights.Quota*quotaFactor
}
if s.service.openAIWSP2CEnabled() {
return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew)
}
}
topK := 0
if len(candidates) > 0 {
topK = s.service.openAIWSLBTopK()
if topK > len(candidates) {
topK = len(candidates)
}
if topK <= 0 {
topK = 1
}
}
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 || topK <= 0 {
return nil
}
groupTopK := topK
if groupTopK > len(pool) {
groupTopK = len(pool)
}
ranked := selectTopKOpenAICandidates(pool, groupTopK)
return buildOpenAIWeightedSelectionOrder(ranked, req)
}
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
if len(pool) == 0 {
return nil
}
ordered := append([]openAIAccountCandidateScore(nil), pool...)
sort.SliceStable(ordered, func(i, j int) bool {
a, b := ordered[i], ordered[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
return ordered
}
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
if req.RequireCompact {
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
for _, candidate := range candidates {
switch openAICompactSupportTier(candidate.account) {
case 2:
supported = append(supported, candidate)
case 1:
unknown = append(unknown, candidate)
}
}
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
}
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
}
} else {
selectionOrder = buildSelectionOrder(candidates)
if s.service.openAIWSP2CEnabled() && len(plan.candidates) > 0 {
return s.selectByPowerOfTwo(ctx, req, plan.candidates, loadSkew)
}
if len(selectionOrder) == 0 {
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0)
}
compactBlocked := false
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
continue
}
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
compactBlocked = true
continue
}
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if acquireErr != nil {
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder)
if acquireErr != nil {
return nil, candidateCount, topK, loadSkew, acquireErr
}
if result != nil {
return result, candidateCount, topK, loadSkew, nil
}
if s.service.concurrencyService != nil {
if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil {
freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap)
if len(freshPlan.selectionOrder) > 0 {
freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder)
if freshAcquireErr != nil {
return nil, candidateCount, topK, loadSkew, freshAcquireErr
}
if freshResult != nil {
return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil
}
compactBlocked = compactBlocked || freshCompactBlocked
selectionOrder = freshPlan.selectionOrder
candidateCount = freshPlan.candidateCount
topK = freshPlan.topK
loadSkew = freshPlan.loadSkew
}
return &AccountSelectionResult{
Account: fresh,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, candidateCount, topK, loadSkew, nil
}
}
@ -899,6 +974,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
if account == nil {
return false
}
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
return false
}
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
return false
}

View File

@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.rateLimitService != nil {
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
return nil, &UpstreamFailoverError{
StatusCode: resp.StatusCode,
ResponseBody: respBody,

View File

@ -22,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
@ -354,6 +355,9 @@ type OpenAIGatewayService struct {
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time
openaiOAuth429WindowStartUnixNano atomic.Int64
openaiOAuth429WindowCount atomic.Int64
openaiWSRetryMetrics openAIWSRetryMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
codexSnapshotThrottle *accountWriteThrottle
@ -417,6 +421,12 @@ func NewOpenAIGatewayService(
responseHeaderFilter: compileResponseHeaderFilter(cfg),
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
}
if rateLimitService != nil {
rateLimitService.SetAccountRuntimeBlocker(svc)
}
if openAITokenProvider != nil {
openAITokenProvider.SetAccountRuntimeBlocker(svc)
}
svc.logOpenAIWSModeBootstrap()
return svc
}
@ -962,7 +972,7 @@ func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context,
zap.String("request_path", strings.TrimSpace(req.URL.Path)),
zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)),
zap.String("request_host", strings.TrimSpace(req.Host)),
zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())),
zap.String("request_client_ip", strings.TrimSpace(ip.GetClientIP(c))),
zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)),
zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))),
zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))),
@ -1381,13 +1391,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
}
hydrated, err := s.hydrateSelectedAccount(ctx, selected)
if err != nil {
return nil, err
}
// 4. 设置粘性会话绑定
// Set sticky session binding
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
}
return s.hydrateSelectedAccount(ctx, selected)
return hydrated, nil
}
// tryStickySessionHit 尝试从粘性会话获取账号。
@ -1430,6 +1445,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(account) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
return nil
}
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
@ -1575,8 +1594,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
if err == nil && result != nil && result.Acquired {
return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
@ -1627,13 +1646,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
if account == nil {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if s.isOpenAIAccountRuntimeBlocked(account) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
} else {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
if selectErr != nil {
return nil, selectErr
}
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
return selection, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@ -1665,6 +1690,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
if !acc.IsSchedulable() {
continue
}
if s.isOpenAIAccountRuntimeBlocked(acc) {
continue
}
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
continue
}
@ -1687,6 +1715,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
})
}
tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) {
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) == 0 {
return nil, false, nil
}
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
selectionOrder := make([]accountWithLoad, 0, len(available))
if requireCompact {
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
for _, item := range available {
if openAICompactSupportTier(item.account) == tier {
out = append(out, item)
}
}
return out
}
selectionOrder = appendTier(selectionOrder, 2)
selectionOrder = appendTier(selectionOrder, 1)
// tier 0 候选作为兜底追加DB recheck 时若发现 cache tier 0 实际
// 已升级为 1/2探测刚跑完cache 尚未刷新),仍可正常命中。
selectionOrder = appendTier(selectionOrder, 0)
} else {
selectionOrder = append(selectionOrder, available...)
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
if selectErr != nil {
return nil, true, selectErr
}
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return selection, true, nil
}
}
return nil, true, nil
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
ordered := append([]*Account(nil), candidates...)
@ -1707,87 +1821,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if err == nil && result != nil && result.Acquired {
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
if selectErr != nil {
return nil, selectErr
}
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
return selection, nil
}
}
} else {
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
shuffleWithinSortGroups(available)
selectionOrder := make([]accountWithLoad, 0, len(available))
if requireCompact {
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
for _, item := range available {
if openAICompactSupportTier(item.account) == tier {
out = append(out, item)
}
}
return out
}
selectionOrder = appendTier(selectionOrder, 2)
selectionOrder = appendTier(selectionOrder, 1)
// tier 0 候选作为兜底追加DB recheck 时若发现 cache tier 0 实际
// 已升级为 1/2探测刚跑完cache 尚未刷新),仍可正常命中。
selectionOrder = appendTier(selectionOrder, 0)
} else {
selectionOrder = append(selectionOrder, available...)
}
for _, item := range selectionOrder {
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
if fresh == nil {
continue
}
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
if fresh == nil {
continue
}
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
continue
}
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
}
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil {
return nil, selectErr
} else if selection != nil {
return selection, nil
} else if attempted {
if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil {
if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil {
return nil, selectErr
} else if selection != nil {
return selection, nil
}
}
}
@ -1868,6 +1923,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(fresh) {
return nil
}
return fresh
}
@ -1889,6 +1947,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
return nil
}
if s.isOpenAIAccountRuntimeBlocked(latest) {
return nil
}
return latest
}
@ -1935,6 +1996,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
}, nil
}
func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) {
selection, err := s.newSelectionResult(ctx, account, true, release, nil)
if err != nil && release != nil {
release()
}
return selection, err
}
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
@ -1996,7 +2065,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// Forward forwards request to OpenAI API
@ -3278,9 +3347,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -3321,12 +3388,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
if s.rateLimitService != nil {
// Passthrough mode preserves the raw upstream error response, but runtime
// account state still needs to be updated so sticky routing can stop
// reusing a freshly rate-limited account.
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
// 避免粘性路由继续复用刚被限流的账号。
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
@ -4075,10 +4139,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
}
// Handle upstream error (mark account status)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
kind := "http_error"
if shouldDisable {
kind = "failover"
@ -4210,12 +4271,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
}
// Track rate limits and decide whether to trigger secondary failover.
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
}
shouldDisable := s.handleOpenAIAccountUpstreamError(
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
)
kind := "http_error"
if shouldDisable {
kind = "failover"

View File

@ -131,8 +131,10 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
c.Request.RemoteAddr = "172.18.0.1:54321"
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("X-Real-IP", "203.0.113.42")
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`)
@ -146,6 +148,8 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
require.True(t, logSink.ContainsFieldValue("request_client_ip", "203.0.113.42"))
require.True(t, logSink.ContainsFieldValue("request_remote_addr", "172.18.0.1:54321"))
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta"))
require.True(t, logSink.ContainsField("request_body_size"))

View File

@ -29,6 +29,69 @@ type openAIResponsesImageResult struct {
Model string
}
type OpenAIImagesUpstreamError struct {
StatusCode int
ErrorType string
Code string
Message string
Param string
UpstreamRequestID string
}
func (e *OpenAIImagesUpstreamError) Error() string {
if e == nil {
return ""
}
code := strings.TrimSpace(e.Code)
if code == "" {
code = strings.TrimSpace(e.ErrorType)
}
message := strings.TrimSpace(e.Message)
if code != "" && message != "" {
return fmt.Sprintf("openai images upstream error: %s: %s", code, message)
}
if message != "" {
return "openai images upstream error: " + message
}
if code != "" {
return "openai images upstream error: " + code
}
return "openai images upstream error"
}
func (e *OpenAIImagesUpstreamError) clientStatusCode() int {
if e == nil {
return http.StatusBadGateway
}
if e.StatusCode > 0 {
return e.StatusCode
}
return http.StatusBadGateway
}
func (e *OpenAIImagesUpstreamError) clientErrorType() string {
if e == nil {
return "upstream_error"
}
if trimmed := strings.TrimSpace(e.ErrorType); trimmed != "" {
return trimmed
}
return "upstream_error"
}
func (e *OpenAIImagesUpstreamError) clientMessage() string {
if e == nil {
return "Upstream request failed"
}
if trimmed := strings.TrimSpace(e.Message); trimmed != "" {
return trimmed
}
if trimmed := strings.TrimSpace(e.Code); trimmed != "" {
return trimmed
}
return "Upstream request failed"
}
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
if strings.TrimSpace(result.Result) != "" {
return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result)
@ -465,6 +528,57 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
}
func extractOpenAIImagesUpstreamError(body []byte) *OpenAIImagesUpstreamError {
var upstreamErr *OpenAIImagesUpstreamError
forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
if upstreamErr != nil || !gjson.ValidBytes(payload) {
return
}
upstreamErr = openAIImagesUpstreamErrorFromSSEPayload(payload)
})
return upstreamErr
}
func openAIImagesUpstreamErrorFromSSEPayload(payload []byte) *OpenAIImagesUpstreamError {
if !gjson.ValidBytes(payload) {
return nil
}
switch gjson.GetBytes(payload, "type").String() {
case "error":
return openAIImagesUpstreamErrorFromGJSON(gjson.GetBytes(payload, "error"), "")
case "response.failed":
response := gjson.GetBytes(payload, "response")
return openAIImagesUpstreamErrorFromGJSON(response.Get("error"), response.Get("id").String())
default:
return nil
}
}
func openAIImagesUpstreamErrorFromGJSON(errorObj gjson.Result, upstreamRequestID string) *OpenAIImagesUpstreamError {
if !errorObj.Exists() {
return nil
}
code := strings.TrimSpace(errorObj.Get("code").String())
errType := strings.TrimSpace(errorObj.Get("type").String())
message := strings.TrimSpace(errorObj.Get("message").String())
param := strings.TrimSpace(errorObj.Get("param").String())
statusCode := http.StatusBadGateway
if strings.EqualFold(code, "moderation_blocked") || strings.EqualFold(errType, "image_generation_user_error") {
statusCode = http.StatusBadRequest
}
if message == "" {
message = "Upstream request failed"
}
return &OpenAIImagesUpstreamError{
StatusCode: statusCode,
ErrorType: errType,
Code: code,
Message: sanitizeUpstreamErrorMessage(message),
Param: param,
UpstreamRequestID: strings.TrimSpace(upstreamRequestID),
}
}
func buildOpenAIImagesAPIResponse(
results []openAIResponsesImageResult,
createdAt int64,
@ -531,6 +645,41 @@ func buildOpenAIImagesStreamErrorBody(message string) []byte {
return body
}
func buildOpenAIImagesStreamErrorBodyFromUpstream(err *OpenAIImagesUpstreamError) []byte {
if err == nil {
return buildOpenAIImagesStreamErrorBody("")
}
body := buildOpenAIImagesStreamErrorBody(err.clientMessage())
body, _ = sjson.SetBytes(body, "error.type", err.clientErrorType())
if code := strings.TrimSpace(err.Code); code != "" {
body, _ = sjson.SetBytes(body, "error.code", code)
}
if param := strings.TrimSpace(err.Param); param != "" {
body, _ = sjson.SetBytes(body, "error.param", param)
}
return body
}
func writeOpenAIImagesUpstreamErrorResponse(c *gin.Context, err *OpenAIImagesUpstreamError) bool {
if c == nil || c.Writer == nil || c.Writer.Written() || err == nil {
return false
}
errorObj := gin.H{
"type": err.clientErrorType(),
"message": err.clientMessage(),
}
if code := strings.TrimSpace(err.Code); code != "" {
errorObj["code"] = code
}
if param := strings.TrimSpace(err.Param); param != "" {
errorObj["param"] = param
}
c.JSON(err.clientStatusCode(), gin.H{
"error": errorObj,
})
return true
}
func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error {
if strings.TrimSpace(eventName) != "" {
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
@ -588,6 +737,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
return OpenAIUsage{}, 0, nil, err
}
if len(results) == 0 {
if upstreamErr := extractOpenAIImagesUpstreamError(body); upstreamErr != nil {
setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "")
writeOpenAIImagesUpstreamErrorResponse(c, upstreamErr)
return OpenAIUsage{}, 0, nil, upstreamErr
}
return OpenAIUsage{}, 0, nil, fmt.Errorf("upstream did not return image output")
}
if strings.TrimSpace(firstMeta.Model) == "" {
@ -742,6 +896,16 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
imageCount = len(emitted)
imageOutputSizes = openAIResponsesImageResultSizes(finalResults)
processDataDone = true
case "error", "response.failed":
if upstreamErr := openAIImagesUpstreamErrorFromSSEPayload(dataBytes); upstreamErr != nil {
if !clientDisconnected {
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBodyFromUpstream(upstreamErr))
}
setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "")
processDataErr = upstreamErr
processDataDone = true
return
}
}
}

View File

@ -553,6 +553,59 @@ func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *te
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
}
func TestOpenAIGatewayServiceForwardImages_OAuthNonStreamModerationBlockedReturnsClientError(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw blocked image","response_format":"b64_json"}`)
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = req
c.Set("api_key", &APIKey{ID: 42})
svc := &OpenAIGatewayService{}
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
require.NoError(t, err)
svc.httpUpstream = &httpUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/event-stream"},
"X-Request-Id": []string{"req_img_blocked"},
},
Body: io.NopCloser(strings.NewReader(
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000020}}\n\n" +
"data: {\"type\":\"error\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}\n\n" +
"data: {\"type\":\"response.failed\",\"response\":{\"id\":\"resp_blocked\",\"status\":\"failed\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}}\n\n",
)),
},
}
account := &Account{
ID: 1,
Name: "openai-oauth",
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token-123",
},
}
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
require.Nil(t, result)
var upstreamErr *OpenAIImagesUpstreamError
require.ErrorAs(t, err, &upstreamErr)
require.Equal(t, http.StatusBadRequest, upstreamErr.StatusCode)
require.Equal(t, "moderation_blocked", upstreamErr.Code)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Equal(t, "image_generation_user_error", gjson.Get(rec.Body.String(), "error.type").String())
require.Equal(t, "moderation_blocked", gjson.Get(rec.Body.String(), "error.code").String())
require.Contains(t, gjson.Get(rec.Body.String(), "error.message").String(), "safety system")
}
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
gin.SetMode(gin.TestMode)
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)

View File

@ -80,6 +80,7 @@ type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
runtimeBlocker AccountRuntimeBlocker
metrics *openAITokenRuntimeMetricsStore
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
p.runtimeBlocker = blocker
}
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
if p == nil {
return OpenAITokenRuntimeMetrics{}
@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account
if p == nil || p.accountRepo == nil || account == nil {
return
}
if p.runtimeBlocker != nil {
p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token")
}
bgCtx := context.Background()
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
slog.Warn("openai_token_provider.set_error_failed",

View File

@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
cache.getErr = errors.New("simulated cache miss")
provider := NewOpenAITokenProvider(repo, cache, nil)
blocker := &runtimeBlockRecorder{}
provider.SetAccountRuntimeBlocker(blocker)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
require.Len(t, blocker.accounts, 1)
require.Equal(t, account.ID, blocker.accounts[0].ID)
require.Equal(t, "missing_refresh_token", blocker.reasons[0])
}

View File

@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return
}
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
}
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {

View File

@ -115,6 +115,10 @@ func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) e
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) {
panic("unexpected call")
}
func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
panic("unexpected call")
}

View File

@ -28,10 +28,16 @@ type RateLimitService struct {
openAI403CounterCache OpenAI403CounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
runtimeBlocker AccountRuntimeBlocker
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type AccountRuntimeBlocker interface {
BlockAccountScheduling(account *Account, until time.Time, reason string)
ClearAccountSchedulingBlock(accountID int64)
}
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
type SuccessfulTestRecoveryResult struct {
ClearedError bool
@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
s.tokenCacheInvalidator = invalidator
}
func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
s.runtimeBlocker = blocker
}
func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
if s == nil || s.runtimeBlocker == nil || account == nil {
return
}
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
}
func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) {
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
return
}
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
}
// ErrorPolicyResult 表示错误策略检查的结果
type ErrorPolicyResult int
@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
cooldownMinutes = 10
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
s.notifyAccountSchedulingBlocked(account, until, "oauth_401")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
}
@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error")
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
s.handleAuthError(ctx, account, msg)
@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code")
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
s.persistOpenAICodexSnapshot(ctx, account, headers)
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
s.notifyAccountSchedulingBlocked(account, *resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 2. Anthropic 平台:尝试解析 per-window 头5h / 7d选择实际触发的窗口
if result := calculateAnthropic429ResetTime(headers); result != nil {
s.notifyAccountSchedulingBlocked(account, result.resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 尝试解析 OpenAI 的 usage_limit_reached 错误
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 尝试解析 Gemini 格式(用于其他平台)
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
resetTime := time.Unix(*resetAt, 0)
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
resetAt := time.Unix(ts, 0)
// 标记限流状态
s.notifyAccountSchedulingBlocked(account, resetAt, "429")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun
resetAt := time.Now().Add(cooldown)
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback")
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
s.notifyAccountSchedulingBlocked(account, until, "529")
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
}
}
s.ResetOpenAI403Counter(ctx, accountID)
s.notifyAccountSchedulingBlockCleared(accountID)
return nil
}
@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
}
if result.ClearedError || result.ClearedRateLimit {
s.ResetOpenAI403Counter(ctx, accountID)
if result.ClearedError && !result.ClearedRateLimit {
s.notifyAccountSchedulingBlockCleared(accountID)
}
}
return result, nil
@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
}
s.notifyAccountSchedulingBlockCleared(accountID)
return nil
}
@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
reason = strings.TrimSpace(state.ErrorMessage)
}
s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
reason = state.ErrorMessage
}
s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable")
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false
@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error")
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false

View File

@ -6,16 +6,36 @@ import (
"context"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type runtimeBlockRecorder struct {
accounts []*Account
until []time.Time
reasons []string
clearedIDs []int64
}
func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) {
r.accounts = append(r.accounts, account)
r.until = append(r.until, until)
r.reasons = append(r.reasons, reason)
}
func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) {
r.clearedIDs = append(r.clearedIDs, accountID)
}
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
counter := &openAI403CounterCacheStub{counts: []int64{1}}
blocker := &runtimeBlockRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetOpenAI403CounterCache(counter)
service.SetAccountRuntimeBlocker(blocker)
account := &Account{
ID: 301,
Platform: PlatformOpenAI,
@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
require.Equal(t, 1, repo.tempCalls)
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
require.Contains(t, repo.lastTempReason, "(1/3)")
require.Len(t, blocker.accounts, 1)
require.Equal(t, account.ID, blocker.accounts[0].ID)
require.Equal(t, "openai_403_temp", blocker.reasons[0])
require.True(t, blocker.until[0].After(time.Now()))
}
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {

View File

@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
},
}
cache := &tempUnschedCacheRecorder{}
blocker := &runtimeBlockRecorder{}
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
svc.SetAccountRuntimeBlocker(blocker)
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
require.NoError(t, err)
@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
require.Equal(t, 1, repo.clearModelRateLimitCalls)
require.Equal(t, 1, repo.clearTempUnschedCalls)
require.Equal(t, []int64{42}, cache.deletedIDs)
require.Equal(t, []int64{42}, blocker.clearedIDs)
}
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {

View File

@ -54,6 +54,7 @@ type RedeemCodeRepository interface {
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
Update(ctx context.Context, code *RedeemCode) error
BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error)
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
@ -82,6 +83,54 @@ type RedeemCodeResponse struct {
CreatedAt time.Time `json:"created_at"`
}
type NullableTimeUpdate struct {
Set bool
Value *time.Time
}
type NullableInt64Update struct {
Set bool
Value *int64
}
type RedeemCodeBatchUpdateFields struct {
Status *string
ExpiresAt NullableTimeUpdate
Notes *string
GroupID NullableInt64Update
// Core fields are intentionally modeled only so service validation can
// reject payloads that try to mutate redemption value semantics in bulk.
Type *string
Value *float64
}
func (f RedeemCodeBatchUpdateFields) HasChanges() bool {
return f.Status != nil ||
f.ExpiresAt.Set ||
f.Notes != nil ||
f.GroupID.Set ||
f.Type != nil ||
f.Value != nil
}
func (f RedeemCodeBatchUpdateFields) HasCoreFieldChanges() bool {
return f.Type != nil || f.Value != nil
}
func (f RedeemCodeBatchUpdateFields) TouchesUsedSensitiveFields() bool {
return f.Status != nil || f.ExpiresAt.Set || f.GroupID.Set
}
type RedeemCodeBatchUpdateInput struct {
IDs []int64
Fields RedeemCodeBatchUpdateFields
}
type RedeemCodeBatchUpdateResult struct {
Updated int64 `json:"updated"`
}
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo RedeemCodeRepository
@ -218,6 +267,61 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
return nil
}
func (s *RedeemService) BatchUpdate(ctx context.Context, input *RedeemCodeBatchUpdateInput) (*RedeemCodeBatchUpdateResult, error) {
if input == nil {
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID", "batch update input is required")
}
if len(input.IDs) == 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required")
}
if !input.Fields.HasChanges() {
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_EMPTY", "at least one field must be selected")
}
if input.Fields.HasCoreFieldChanges() {
return nil, infraerrors.BadRequest("REDEEM_CODE_CORE_FIELDS_IMMUTABLE", "type and value cannot be batch updated")
}
ids := make([]int64, 0, len(input.IDs))
seen := make(map[int64]struct{}, len(input.IDs))
for _, id := range input.IDs {
if id <= 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID_ID", "ids must be positive")
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
if len(ids) == 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required")
}
if input.Fields.Status != nil {
switch *input.Fields.Status {
case StatusUnused, StatusDisabled:
default:
return nil, infraerrors.BadRequest("REDEEM_CODE_STATUS_INVALID", "status must be unused or disabled")
}
}
if input.Fields.ExpiresAt.Set && input.Fields.ExpiresAt.Value != nil {
expiresAt := input.Fields.ExpiresAt.Value.UTC()
if !expiresAt.After(time.Now().UTC()) {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
}
input.Fields.ExpiresAt.Value = &expiresAt
}
if input.Fields.GroupID.Set && input.Fields.GroupID.Value != nil && *input.Fields.GroupID.Value <= 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_GROUP_ID_INVALID", "group_id must be positive")
}
updated, err := s.redeemRepo.BatchUpdate(ctx, ids, input.Fields)
if err != nil {
return nil, err
}
return &RedeemCodeBatchUpdateResult{Updated: updated}, nil
}
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {

View File

@ -0,0 +1,75 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
func TestRedeemService_BatchUpdate_PartialFields(t *testing.T) {
status := StatusDisabled
notes := "maintenance window"
expiresAt := time.Now().UTC().Add(24 * time.Hour)
repo := &redeemRepoStub{}
svc := &RedeemService{redeemRepo: repo}
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
IDs: []int64{1, 2, 2},
Fields: RedeemCodeBatchUpdateFields{
Status: &status,
ExpiresAt: NullableTimeUpdate{Set: true, Value: &expiresAt},
Notes: &notes,
},
})
require.NoError(t, err)
require.Equal(t, int64(2), result.Updated)
require.True(t, repo.batchUpdateCalled)
require.Equal(t, []int64{1, 2}, repo.batchUpdateIDs)
require.Equal(t, &status, repo.batchUpdateFields.Status)
require.True(t, repo.batchUpdateFields.ExpiresAt.Set)
require.WithinDuration(t, expiresAt, *repo.batchUpdateFields.ExpiresAt.Value, time.Second)
require.Equal(t, &notes, repo.batchUpdateFields.Notes)
require.False(t, repo.batchUpdateFields.GroupID.Set)
require.Nil(t, repo.batchUpdateFields.Type)
require.Nil(t, repo.batchUpdateFields.Value)
}
func TestRedeemService_BatchUpdate_RejectsInvalidID(t *testing.T) {
repo := &redeemRepoStub{}
svc := &RedeemService{redeemRepo: repo}
notes := "bad id"
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
IDs: []int64{1, 0},
Fields: RedeemCodeBatchUpdateFields{Notes: &notes},
})
require.Nil(t, result)
require.Error(t, err)
require.True(t, infraerrors.IsBadRequest(err))
require.False(t, repo.batchUpdateCalled)
}
func TestRedeemService_BatchUpdate_RejectsCoreFieldsForUsedCodes(t *testing.T) {
repo := &redeemRepoStub{}
svc := &RedeemService{redeemRepo: repo}
newValue := 100.0
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
IDs: []int64{42},
Fields: RedeemCodeBatchUpdateFields{
Value: &newValue,
},
})
require.Nil(t, result)
require.Error(t, err)
require.True(t, infraerrors.IsBadRequest(err))
require.False(t, repo.batchUpdateCalled)
}

View File

@ -26,12 +26,17 @@ func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool {
if len(whitelist) == 0 {
return true
}
suffix := RegistrationEmailSuffix(email)
if suffix == "" {
_, domain, ok := splitEmailForPolicy(email)
if !ok {
return false
}
suffix := "@" + domain
for _, allowed := range whitelist {
if suffix == allowed {
allowed = strings.ToLower(strings.TrimSpace(allowed))
if strings.HasPrefix(allowed, "@") && suffix == allowed {
return true
}
if strings.HasPrefix(allowed, "*.") && registrationEmailDomainMatchesWildcard(domain, allowed) {
return true
}
}
@ -98,6 +103,14 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) {
return "", nil
}
if strings.HasPrefix(value, "*.") {
domain := strings.TrimPrefix(value, "*.")
if !isValidRegistrationEmailDomain(domain) {
return "", fmt.Errorf("invalid email suffix: %q", raw)
}
return "*." + domain, nil
}
domain := value
if strings.Contains(value, "@") {
if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 {
@ -106,13 +119,27 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) {
domain = strings.TrimPrefix(value, "@")
}
if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) {
if !isValidRegistrationEmailDomain(domain) {
return "", fmt.Errorf("invalid email suffix: %q", raw)
}
return "@" + domain, nil
}
func isValidRegistrationEmailDomain(domain string) bool {
return domain != "" &&
!strings.Contains(domain, "@") &&
registrationEmailDomainPattern.MatchString(domain)
}
func registrationEmailDomainMatchesWildcard(domain string, allowed string) bool {
base := strings.TrimPrefix(allowed, "*.")
if !isValidRegistrationEmailDomain(base) {
return false
}
return domain == base || strings.HasSuffix(domain, "."+base)
}
func splitEmailForPolicy(raw string) (local string, domain string, ok bool) {
email := strings.ToLower(strings.TrimSpace(raw))
local, domain, found := strings.Cut(email, "@")

View File

@ -9,23 +9,36 @@ import (
)
func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) {
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "})
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar ", "*.EDU.CN"})
require.NoError(t, err)
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got)
}
func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"})
require.Error(t, err)
for _, item := range []string{"@invalid_domain", "*.", "*", "*.@", "*.foo"} {
t.Run(item, func(t *testing.T) {
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{item})
require.Error(t, err)
})
}
}
func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) {
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`)
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","*.EDU.CN","@invalid_domain","*.foo"]`)
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got)
}
func TestIsRegistrationEmailSuffixAllowed(t *testing.T) {
require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"}))
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@qq.com", []string{"@qq.com"}))
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.qq.com", []string{"@qq.com"}))
require.True(t, IsRegistrationEmailSuffixAllowed("student@cs.edu.cn", []string{"*.edu.cn"}))
require.True(t, IsRegistrationEmailSuffixAllowed("student@edu.cn", []string{"*.edu.cn"}))
require.False(t, IsRegistrationEmailSuffixAllowed("student@foo.cn", []string{"*.edu.cn"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@a.com", []string{"@a.com", "*.b.cn"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@school.b.cn", []string{"@a.com", "*.b.cn"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@b.cn", []string{"@a.com", "*.b.cn"}))
require.False(t, IsRegistrationEmailSuffixAllowed("user@c.cn", []string{"@a.com", "*.b.cn"}))
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{}))
}

View File

@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul
}
}
func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) {
cache := &snapshotHydrationCache{
accounts: map[int64]*Account{},
}
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil)
svc := &OpenAIGatewayService{
schedulerSnapshot: schedulerSnapshot,
}
releaseCalls := 0
selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() {
releaseCalls++
})
if err == nil {
t.Fatalf("expected hydration error")
}
if selection != nil {
t.Fatalf("expected nil selection on hydration error")
}
if releaseCalls != 1 {
t.Fatalf("expected release to be called once, got %d", releaseCalls)
}
}
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
cache := &snapshotHydrationCache{
snapshot: []*Account{

View File

@ -597,6 +597,23 @@ func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
s.proxyRepo = repo
}
func (s *SettingService) LoadAPIKeyACLTrustForwardedIPSetting(ctx context.Context) error {
if s == nil || s.cfg == nil || s.settingRepo == nil {
return nil
}
value, err := s.settingRepo.GetValue(ctx, SettingKeyAPIKeyACLTrustForwardedIP)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
s.cfg.SetTrustForwardedIPForAPIKeyACL(s.cfg.Security.TrustForwardedIPForAPIKeyACL)
return nil
}
return fmt.Errorf("get api key acl forwarded ip setting: %w", err)
}
enabled := value == "true"
s.cfg.SetTrustForwardedIPForAPIKeyACL(enabled)
return nil
}
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
@ -633,6 +650,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyLoginAgreementDocuments,
SettingKeyTurnstileEnabled,
SettingKeyTurnstileSiteKey,
SettingKeyAPIKeyACLTrustForwardedIP,
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
@ -1568,6 +1586,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
if settings.TurnstileSecretKey != "" {
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
}
updates[SettingKeyAPIKeyACLTrustForwardedIP] = strconv.FormatBool(settings.APIKeyACLTrustForwardedIP)
// LinuxDo Connect OAuth 登录
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
@ -1777,10 +1796,11 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
// Balance low notification
// 余额、订阅到期与账号限额通知
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
updates[SettingKeySubscriptionExpiryNotifyEnabled] = strconv.FormatBool(settings.SubscriptionExpiryNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
@ -1867,6 +1887,9 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
enabled: settings.OpenAIAdvancedSchedulerEnabled,
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
})
if s.cfg != nil {
s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP)
}
if s.onUpdate != nil {
s.onUpdate() // Invalidate cache after settings update
}
@ -2463,6 +2486,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
SettingKeyAPIKeyACLTrustForwardedIP: "false",
SettingKeySiteName: "Sub2API",
SettingKeySiteLogo: "",
SettingKeyPurchaseSubscriptionEnabled: "false",
@ -2622,6 +2646,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
if loginAgreementUpdatedAt == "" {
loginAgreementUpdatedAt = defaultLoginAgreementDate
}
apiKeyACLTrustForwardedIP := false
if value, ok := settings[SettingKeyAPIKeyACLTrustForwardedIP]; ok {
apiKeyACLTrustForwardedIP = value == "true"
} else if s != nil && s.cfg != nil {
apiKeyACLTrustForwardedIP = s.cfg.Security.TrustForwardedIPForAPIKeyACL
}
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: emailVerifyEnabled,
@ -2644,6 +2674,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
APIKeyACLTrustForwardedIP: apiKeyACLTrustForwardedIP,
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
@ -3131,14 +3162,15 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
// Balance low notification
// 余额、订阅到期与账号限额通知
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
result.BalanceLowNotifyThreshold = v
}
result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
result.SubscriptionExpiryNotifyEnabled = !isFalseSettingValue(settings[SettingKeySubscriptionExpiryNotifyEnabled])
// Account quota notification
// 账号限额通知
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)

View File

@ -53,14 +53,14 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
values: map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`,
SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","*.EDU.CN","@invalid_domain",""]`,
},
}
svc := NewSettingService(repo, &config.Config{})
settings, err := svc.GetPublicSettings(context.Background())
require.NoError(t, err)
require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, settings.RegistrationEmailSuffixWhitelist)
}
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {

View File

@ -213,10 +213,10 @@ func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normaliz
svc := NewSettingService(repo, &config.Config{})
err := svc.UpdateSettings(context.Background(), &SystemSettings{
RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "},
RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar ", "*.EDU.CN"},
})
require.NoError(t, err)
require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
require.Equal(t, `["@example.com","@foo.bar","*.edu.cn"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
}
func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
@ -290,6 +290,30 @@ func TestSettingService_UpdateSettings_AntigravityUserAgentVersion(t *testing.T)
require.Equal(t, "1.23.2", repo.updates[SettingKeyAntigravityUserAgentVersion])
}
func TestSettingService_UpdateSettings_APIKeyACLTrustForwardedIPRefreshesConfig(t *testing.T) {
repo := &settingUpdateRepoStub{}
cfg := &config.Config{}
svc := NewSettingService(repo, cfg)
err := svc.UpdateSettings(context.Background(), &SystemSettings{
APIKeyACLTrustForwardedIP: true,
})
require.NoError(t, err)
require.Equal(t, "true", repo.updates[SettingKeyAPIKeyACLTrustForwardedIP])
require.True(t, cfg.Security.TrustForwardedIPForAPIKeyACL)
require.True(t, cfg.TrustForwardedIPForAPIKeyACL())
}
func TestSettingService_ParseSettings_APIKeyACLTrustForwardedIPFallsBackToConfigWhenMissing(t *testing.T) {
cfg := &config.Config{}
cfg.Security.TrustForwardedIPForAPIKeyACL = true
svc := NewSettingService(&settingUpdateRepoStub{}, cfg)
got := svc.parseSettings(map[string]string{})
require.True(t, got.APIKeyACLTrustForwardedIP)
}
func TestSettingService_GetAntigravityUserAgentVersion_Precedence(t *testing.T) {
t.Run("后台设置优先", func(t *testing.T) {
svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{

View File

@ -38,6 +38,7 @@ type SystemSettings struct {
TurnstileSiteKey string
TurnstileSecretKey string
TurnstileSecretKeyConfigured bool
APIKeyACLTrustForwardedIP bool
// LinuxDo Connect OAuth 登录
LinuxDoConnectEnabled bool
@ -204,15 +205,18 @@ type SystemSettings struct {
PaymentVisibleMethodAlipayEnabled bool
PaymentVisibleMethodWxpayEnabled bool
// OpenAI account scheduling
// OpenAI 账号调度
OpenAIAdvancedSchedulerEnabled bool
// Balance low notification
// 余额不足提醒
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64
BalanceLowNotifyRechargeURL string
// Account quota notification
// 订阅到期提醒
SubscriptionExpiryNotifyEnabled bool
// 账号限额通知
AccountQuotaNotifyEnabled bool
AccountQuotaNotifyEmails []NotifyEmailEntry
}

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"strconv"
@ -14,6 +15,7 @@ import (
// SubscriptionExpiryService periodically updates expired subscription status.
type SubscriptionExpiryService struct {
userSubRepo UserSubscriptionRepository
settingRepo SettingRepository
notificationEmailService *NotificationEmailService
interval time.Duration
stopCh chan struct{}
@ -29,6 +31,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv
}
}
func (s *SubscriptionExpiryService) SetSettingRepository(settingRepo SettingRepository) {
s.settingRepo = settingRepo
}
func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
s.notificationEmailService = notificationEmailService
}
@ -84,6 +90,9 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil {
return
}
if !s.expiryReminderEnabled(ctx) {
return
}
for page := 1; ; page++ {
subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc")
if err != nil {
@ -99,6 +108,21 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
}
}
func (s *SubscriptionExpiryService) expiryReminderEnabled(ctx context.Context) bool {
if s == nil || s.settingRepo == nil {
return true
}
value, err := s.settingRepo.GetValue(ctx, SettingKeySubscriptionExpiryNotifyEnabled)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return true
}
log.Printf("[SubscriptionExpiry] Read expiry reminder switch failed: %v", err)
return false
}
return !isFalseSettingValue(value)
}
func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) {
if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" {
return

View File

@ -0,0 +1,164 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type subscriptionExpiryRepoStub struct {
listCalls int
}
func (r *subscriptionExpiryRepoStub) Create(context.Context, *UserSubscription) error {
return nil
}
func (r *subscriptionExpiryRepoStub) GetByID(context.Context, int64) (*UserSubscription, error) {
return nil, ErrSubscriptionNotFound
}
func (r *subscriptionExpiryRepoStub) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
return nil, ErrSubscriptionNotFound
}
func (r *subscriptionExpiryRepoStub) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
return nil, ErrSubscriptionNotFound
}
func (r *subscriptionExpiryRepoStub) Update(context.Context, *UserSubscription) error {
return nil
}
func (r *subscriptionExpiryRepoStub) Delete(context.Context, int64) error {
return nil
}
func (r *subscriptionExpiryRepoStub) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
return nil, nil
}
func (r *subscriptionExpiryRepoStub) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
return nil, nil
}
func (r *subscriptionExpiryRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
return nil, nil, nil
}
func (r *subscriptionExpiryRepoStub) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
r.listCalls++
return nil, &pagination.PaginationResult{Page: 1, Pages: 1}, nil
}
func (r *subscriptionExpiryRepoStub) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
return false, nil
}
func (r *subscriptionExpiryRepoStub) ExtendExpiry(context.Context, int64, time.Time) error {
return nil
}
func (r *subscriptionExpiryRepoStub) UpdateStatus(context.Context, int64, string) error {
return nil
}
func (r *subscriptionExpiryRepoStub) UpdateNotes(context.Context, int64, string) error {
return nil
}
func (r *subscriptionExpiryRepoStub) ActivateWindows(context.Context, int64, time.Time) error {
return nil
}
func (r *subscriptionExpiryRepoStub) ResetDailyUsage(context.Context, int64, time.Time) error {
return nil
}
func (r *subscriptionExpiryRepoStub) ResetWeeklyUsage(context.Context, int64, time.Time) error {
return nil
}
func (r *subscriptionExpiryRepoStub) ResetMonthlyUsage(context.Context, int64, time.Time) error {
return nil
}
func (r *subscriptionExpiryRepoStub) IncrementUsage(context.Context, int64, float64) error {
return nil
}
func (r *subscriptionExpiryRepoStub) BatchUpdateExpiredStatus(context.Context) (int64, error) {
return 0, nil
}
type subscriptionExpirySettingRepoStub struct {
values map[string]string
err error
}
func (r *subscriptionExpirySettingRepoStub) Get(context.Context, string) (*Setting, error) {
return nil, ErrSettingNotFound
}
func (r *subscriptionExpirySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
if r.err != nil {
return "", r.err
}
value, ok := r.values[key]
if !ok {
return "", ErrSettingNotFound
}
return value, nil
}
func (r *subscriptionExpirySettingRepoStub) Set(context.Context, string, string) error {
return nil
}
func (r *subscriptionExpirySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
return nil, nil
}
func (r *subscriptionExpirySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
return nil
}
func (r *subscriptionExpirySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
return nil, nil
}
func (r *subscriptionExpirySettingRepoStub) Delete(context.Context, string) error {
return nil
}
func TestSubscriptionExpiryService_ExpiryReminderEnabledDefaultsToTrue(t *testing.T) {
svc := NewSubscriptionExpiryService(nil, time.Minute)
svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{values: map[string]string{}})
require.True(t, svc.expiryReminderEnabled(context.Background()))
}
func TestSubscriptionExpiryService_ExpiryReminderDisabledSkipsSubscriptionScan(t *testing.T) {
repo := &subscriptionExpiryRepoStub{}
settingRepo := &subscriptionExpirySettingRepoStub{
values: map[string]string{SettingKeySubscriptionExpiryNotifyEnabled: "false"},
}
svc := NewSubscriptionExpiryService(repo, time.Minute)
svc.SetSettingRepository(settingRepo)
svc.SetNotificationEmailService(NewNotificationEmailService(settingRepo, nil))
svc.sendExpiryReminders(context.Background())
require.Zero(t, repo.listCalls)
}
func TestSubscriptionExpiryService_ExpiryReminderSettingReadErrorFailsClosed(t *testing.T) {
svc := NewSubscriptionExpiryService(nil, time.Minute)
svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{err: errors.New("db down")})
require.False(t, svc.expiryReminderEnabled(context.Background()))
}

View File

@ -27,6 +27,7 @@ type TokenRefreshService struct {
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
refreshAPI *OAuthRefreshAPI // 统一刷新 API
runtimeBlocker AccountRuntimeBlocker
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
privacyClientFactory PrivacyClientFactory
@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
s.refreshPolicy = policy
}
func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
s.runtimeBlocker = blocker
}
func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
if s == nil || s.runtimeBlocker == nil || account == nil {
return
}
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
}
func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) {
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
return
}
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
}
// Start 启动后台刷新服务
func (s *TokenRefreshService) Start() {
if !s.cfg.Enabled {
@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 不可重试错误invalid_grant/invalid_client 等)直接标记 error 状态并返回
if isNonRetryableRefreshError(err) {
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable")
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
slog.Error("token_refresh.set_error_status_failed",
"account_id", account.ID,
@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
// 设置临时不可调度 10 分钟(不标记 error保持 status=active 让下个刷新周期能继续尝试)
until := time.Now().Add(tokenRefreshTempUnschedDuration)
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted")
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
slog.Warn("token_refresh.set_temp_unschedulable_failed",
"account_id", account.ID,
@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
)
} else {
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
s.notifyAccountSchedulingBlockCleared(account.ID)
}
}
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
)
} else {
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
s.notifyAccountSchedulingBlockCleared(account.ID)
}
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
if s.tempUnschedCache != nil {
@ -417,11 +440,12 @@ func isNonRetryableRefreshError(err error) bool {
}
msg := strings.ToLower(err.Error())
nonRetryable := []string{
"invalid_grant", // refresh_token 已失效
"invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
"invalid_grant", // refresh_token 已失效
"refresh_token_reused", // OpenAI refresh_token 已被使用,必须重新授权
"invalid_client", // 客户端配置错误
"unauthorized_client", // 客户端未授权
"access_denied", // 访问被拒绝
"missing_project_id", // 缺少 project_id
"no refresh token available",
}
for _, needle := range nonRetryable {

View File

@ -532,6 +532,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
{name: "network_error", err: errors.New("network timeout"), expected: false},
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
{name: "refresh_token_reused", err: errors.New(`OPENAI_OAUTH_TOKEN_REFRESH_FAILED: token refresh failed: status 401, body: {"error":{"code":"refresh_token_reused"}}`), expected: true},
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
{name: "access_denied", err: errors.New("access_denied"), expected: true},
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},

View File

@ -62,6 +62,7 @@ func ProvideTokenRefreshService(
privacyClientFactory PrivacyClientFactory,
proxyRepo ProxyRepository,
refreshAPI *OAuthRefreshAPI,
runtimeBlocker AccountRuntimeBlocker,
) *TokenRefreshService {
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
// 注入 OpenAI privacy opt-out 依赖
@ -70,6 +71,7 @@ func ProvideTokenRefreshService(
svc.SetRefreshAPI(refreshAPI)
// 调用侧显式注入后台刷新策略,避免策略漂移
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
svc.SetAccountRuntimeBlocker(runtimeBlocker)
svc.Start()
return svc
}
@ -154,8 +156,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
}
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, settingRepo SettingRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
svc.SetSettingRepository(settingRepo)
svc.SetNotificationEmailService(notificationEmailService)
svc.Start()
return svc
@ -185,6 +188,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
}
if cfg != nil {
svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond)
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
}
return svc
@ -400,6 +404,9 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
svc := NewSettingService(settingRepo, cfg)
svc.SetDefaultSubscriptionGroupReader(groupRepo)
svc.SetProxyRepository(proxyRepo)
if err := svc.LoadAPIKeyACLTrustForwardedIPSetting(context.Background()); err != nil {
logger.LegacyPrintf("service.setting", "Warning: load api key acl forwarded ip setting failed: %v", err)
}
antigravity.SetUserAgentVersionResolver(svc.GetAntigravityUserAgentVersion)
return svc
}
@ -455,6 +462,7 @@ var ProviderSet = wire.NewSet(
NewRPMTokenBucketService,
NewGatewayService,
NewOpenAIGatewayService,
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,

View File

@ -0,0 +1,12 @@
-- 修复user_provider_default_grants 表的 provider_type check 约束
-- 与 auth_identities / auth_identity_channels / pending_auth_sessions 保持一致,
-- 否则启用了 auth_source_default_{github,google,dingtalk}_grant_on_first_bind
-- 之后OAuth 首次绑定流程会因约束违反而失败。
-- 参见 migrations 135、136 漏改本表。
ALTER TABLE user_provider_default_grants
DROP CONSTRAINT IF EXISTS user_provider_default_grants_provider_type_check;
ALTER TABLE user_provider_default_grants
ADD CONSTRAINT user_provider_default_grants_provider_type_check
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));

View File

@ -0,0 +1,4 @@
-- 订阅到期提醒邮件开关,默认保持历史行为:开启。
INSERT INTO settings (key, value, updated_at)
VALUES ('subscription_expiry_notify_enabled', 'true', NOW())
ON CONFLICT (key) DO NOTHING;

View File

@ -405,6 +405,9 @@ gateway:
# Enable batch load calculation for scheduling
# 启用调度批量负载计算
load_batch_enabled: true
# Tiny in-process TTL for batch load reads in milliseconds (0 disables)
# 调度批量负载读取的进程内短缓存 TTL毫秒0 表示禁用)
load_batch_cache_ttl_ms: 200
# Slot cleanup interval (duration)
# 并发槽位清理周期(时间段)
slot_cleanup_interval: 30s

View File

@ -57,5 +57,10 @@
"vite-plugin-checker": "^0.9.1",
"vitest": "^2.1.9",
"vue-tsc": "^2.2.0"
},
"pnpm": {
"overrides": {
"js-cookie": "3.0.7"
}
}
}

View File

@ -4,6 +4,9 @@ settings:
autoInstallPeers: true
excludeLinksFromLockfile: false
overrides:
js-cookie: 3.0.7
importers:
.:
@ -2882,9 +2885,9 @@ packages:
engines: {node: '>=14'}
hasBin: true
js-cookie@3.0.5:
resolution: {integrity: sha512-cEiJEAEoIbWfCZYKWhVwFuvPX1gETRYPw6LlaTKoxD3s2AkXzkCjnp6h0V77ozyqj0jakteJ4YqDJT830+lVGw==}
engines: {node: '>=14'}
js-cookie@3.0.7:
resolution: {integrity: sha512-z/wZZgDrkNV1eA0ULjM/F9/50Ya8fbzgKneSpoPsXSGd0KnpdtHfOZWK+GcwLk+EZbS4F9RBhU+K2RgzuDaItw==}
engines: {node: '>=20'}
js-tokens@4.0.0:
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
@ -6365,7 +6368,7 @@ snapshots:
'@types/js-cookie': 3.0.6
dayjs: 1.11.20
intersection-observer: 0.12.2
js-cookie: 3.0.5
js-cookie: 3.0.7
lodash: 4.18.1
react: 19.2.3
react-dom: 19.2.3(react@19.2.3)
@ -7656,10 +7659,10 @@ snapshots:
config-chain: 1.1.13
editorconfig: 1.0.4
glob: 10.5.0
js-cookie: 3.0.5
js-cookie: 3.0.7
nopt: 7.2.1
js-cookie@3.0.5: {}
js-cookie@3.0.7: {}
js-tokens@4.0.0: {}

View File

@ -7,6 +7,7 @@ import { apiClient } from '../client'
import type {
RedeemCode,
GenerateRedeemCodesRequest,
BatchUpdateRedeemCodeFields,
RedeemCodeType,
PaginatedResponse
} from '@/types'
@ -23,7 +24,7 @@ export async function list(
pageSize: number = 20,
filters?: {
type?: RedeemCodeType
status?: 'active' | 'used' | 'expired' | 'unused'
status?: 'active' | 'used' | 'expired' | 'unused' | 'disabled'
search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
@ -118,6 +119,26 @@ export async function batchDelete(ids: number[]): Promise<{
return data
}
/**
* Batch update selected redeem code fields
* @param ids - Array of redeem code IDs
* @param fields - Field collection to update
* @returns Updated count
*/
export async function batchUpdate(
ids: number[],
fields: BatchUpdateRedeemCodeFields
): Promise<{
updated: number
message: string
}> {
const { data } = await apiClient.post<{
updated: number
message: string
}>('/admin/redeem-codes/batch-update', { ids, fields })
return data
}
/**
* Expire redeem code
* @param id - Redeem code ID
@ -158,7 +179,7 @@ export async function getStats(): Promise<{
*/
export async function exportCodes(filters?: {
type?: RedeemCodeType
status?: 'used' | 'expired' | 'unused'
status?: 'used' | 'expired' | 'unused' | 'disabled'
search?: string
sort_by?: string
sort_order?: 'asc' | 'desc'
@ -176,6 +197,7 @@ export const redeemAPI = {
generate,
delete: deleteCode,
batchDelete,
batchUpdate,
expire,
getStats,
exportCodes

View File

@ -2,6 +2,12 @@ import { apiClient } from '../client'
export type ModerationMode = 'off' | 'observe' | 'pre_block'
export type KeywordBlockingMode = 'keyword_only' | 'keyword_and_api' | 'api_only'
export type ContentModerationModelFilterType = 'all' | 'include' | 'exclude'
export interface ContentModerationModelFilter {
type: ContentModerationModelFilterType
models: string[]
}
export interface ContentModerationConfig {
enabled: boolean
@ -32,6 +38,7 @@ export interface ContentModerationConfig {
pre_hash_check_enabled: boolean
blocked_keywords: string[]
keyword_blocking_mode: KeywordBlockingMode
model_filter: ContentModerationModelFilter
}
export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen'
@ -105,6 +112,7 @@ export interface UpdateContentModerationConfig {
pre_hash_check_enabled?: boolean
blocked_keywords?: string[]
keyword_blocking_mode?: KeywordBlockingMode
model_filter?: ContentModerationModelFilter
}
export interface ContentModerationRuntimeStatus {

Some files were not shown because too many files have changed in this diff Show More