chore: merge upstream Wei-Shaw/sub2api latest (v0.1.130+)
Upstream features: bedrock CC compat, email whitelist wildcard, content moderation per-model toggle, redeem code batch update, OIDC verified-email fast path, subscription expiry email, cache hit rate fix, audit dedup, js-cookie security fix, x/net vulnerability fix, OpenAI account cooldown optimization, reverse proxy client IP fix, API key ACL trusted forwarded IP. Local additions preserved: rpmTokenBucketService, quotaFactor scoring, P2C scheduler selection.
This commit is contained in:
commit
e938be5f3f
@ -114,6 +114,12 @@ Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to recei
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
|
||||
<td>Thanks to RunAPI for sponsoring this project! <a href="https://runapi.co/register?aff=fu2E">RunAPI</a> is an efficient and stable API platform and OpenRouter alternative. With one API Key, you can access 150+ popular models including OpenAI, Claude, Gemini, DeepSeek, and Grok, with pricing as low as 10% of the original rate. It is highly stable and seamlessly compatible with tools such as Claude Code and OpenClaw.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
@ -112,6 +112,13 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
<td>感谢 PPToken.org 赞助本项目! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:1,1 元=1 美元额度;GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术,10 分钟内有回复 。赞助商福利:前 200 名用户通过 <a href="https://api.pptoken.org/register?promo=SUB2API">[专属注册链接]</a> 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
|
||||
<td>感谢 RunAPI 赞助本项目! <a href="https://runapi.co/register?aff=fu2E">RunAPI</a> 是高效稳定的API OpenRouter平替平台,一个 API Key 即可访问 OpenAI、Claude、Gemini、DeepSeek、Grok 等 150+ 主流模型,低至 1 折,极其稳定,可以无缝兼容 Claude Code、OpenClaw 等工具。
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
@ -113,6 +113,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://runapi.co/register?aff=fu2E"><img src="assets/partners/logos/runapi.png" alt="RunAPI" width="150"></a></td>
|
||||
<td>RunAPI のご支援に感謝します!<a href="https://runapi.co/register?aff=fu2E">RunAPI</a> は効率的で安定した API プラットフォームで、OpenRouter の代替として利用できます。1つの API キーで OpenAI、Claude、Gemini、DeepSeek、Grok など 150以上の主要モデルにアクセスでき、価格は最低 10% から。非常に安定しており、Claude Code や OpenClaw などのツールとシームレスに互換します。
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/runapi.png
Normal file
BIN
assets/partners/logos/runapi.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 13 KiB |
@ -1 +1 @@
|
||||
0.1.129
|
||||
0.1.130
|
||||
|
||||
@ -24,8 +24,6 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
//go:embed VERSION
|
||||
@ -116,11 +114,16 @@ func runSetupServer() {
|
||||
log.Printf("Setup wizard available at http://%s", addr)
|
||||
log.Println("Complete the setup wizard to configure Sub2API")
|
||||
|
||||
protocols := new(http.Protocols)
|
||||
protocols.SetHTTP1(true)
|
||||
protocols.SetUnencryptedHTTP2(true)
|
||||
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: h2c.NewHandler(r, &http2.Server{}),
|
||||
Handler: r,
|
||||
ReadHeaderTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
Protocols: protocols,
|
||||
}
|
||||
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
|
||||
@ -113,23 +113,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
|
||||
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
|
||||
privacyClientFactory := providePrivacyClientFactory()
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
|
||||
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
|
||||
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
|
||||
@ -138,6 +133,30 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
|
||||
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
openAIOAuthClient := repository.NewOpenAIOAuthClient()
|
||||
openAIOAuthService := service.NewOpenAIOAuthService(proxyRepository, openAIOAuthClient)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory, openAIGatewayService)
|
||||
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
|
||||
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
|
||||
rpmCache := repository.NewRPMCache(redisClient)
|
||||
groupCapacityService := service.NewGroupCapacityService(accountRepository, groupRepository, concurrencyService, sessionLimitCache, rpmCache)
|
||||
groupHandler := admin.NewGroupHandler(adminService, dashboardService, groupCapacityService)
|
||||
claudeOAuthClient := repository.NewClaudeOAuthClient()
|
||||
oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient)
|
||||
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
|
||||
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
|
||||
driveClient := repository.NewGeminiDriveClient()
|
||||
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, driveClient, configConfig)
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
|
||||
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
|
||||
usageCache := service.NewUsageCache()
|
||||
@ -146,12 +165,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService)
|
||||
oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache)
|
||||
geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI)
|
||||
claudeTokenProvider := service.ProvideClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService, oAuthRefreshAPI)
|
||||
gatewayCache := repository.NewGatewayCache(redisClient)
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
@ -176,25 +191,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
|
||||
promoHandler := admin.NewPromoHandler(promoService)
|
||||
opsRepository := repository.NewOpsRepository(db)
|
||||
usageBillingRepository := repository.NewUsageBillingRepository(client, db)
|
||||
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
|
||||
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
billingService := service.NewBillingService(configConfig, pricingService)
|
||||
identityService := service.NewIdentityService(identityCache)
|
||||
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
|
||||
digestSessionStore := service.NewDigestSessionStore()
|
||||
channelRepository := repository.NewChannelRepository(db)
|
||||
channelService := service.NewChannelService(channelRepository, groupRepository, apiKeyAuthCacheInvalidator, pricingService)
|
||||
modelPricingResolver := service.NewModelPricingResolver(channelService, billingService)
|
||||
notificationEmailService := service.NewNotificationEmailService(settingRepository, emailService)
|
||||
balanceNotifyService := service.ProvideBalanceNotifyService(emailService, settingRepository, accountRepository, notificationEmailService)
|
||||
rpmTokenBucketService := service.NewRPMTokenBucketService()
|
||||
gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver, balanceNotifyService, rpmTokenBucketService)
|
||||
openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oAuthRefreshAPI)
|
||||
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService, balanceNotifyService, settingService)
|
||||
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
|
||||
opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository)
|
||||
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink)
|
||||
@ -271,9 +271,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
|
||||
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig, channelMonitorService, settingRepository, opsService)
|
||||
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI)
|
||||
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig, tempUnschedCache, privacyClientFactory, proxyRepository, oAuthRefreshAPI, openAIGatewayService)
|
||||
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, notificationEmailService)
|
||||
subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository, settingRepository, notificationEmailService)
|
||||
scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig)
|
||||
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
|
||||
windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository)
|
||||
|
||||
@ -15,6 +15,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0
|
||||
github.com/coder/websocket v1.8.14
|
||||
github.com/dgraph-io/ristretto v0.2.0
|
||||
github.com/docker/docker v28.5.2+incompatible
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/google/uuid v1.6.0
|
||||
@ -41,11 +42,11 @@ require (
|
||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.21
|
||||
github.com/zeromicro/go-zero v1.9.4
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/crypto v0.51.0
|
||||
golang.org/x/image v0.39.0
|
||||
golang.org/x/net v0.53.0
|
||||
golang.org/x/net v0.55.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/term v0.43.0
|
||||
google.golang.org/protobuf v1.36.10
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@ -87,7 +88,6 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/docker v28.5.2+incompatible // indirect
|
||||
github.com/docker/go-connections v0.6.0 // indirect
|
||||
github.com/docker/go-units v0.5.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
@ -107,7 +107,6 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/subcommands v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
|
||||
@ -175,10 +174,10 @@ require (
|
||||
go.uber.org/multierr v1.9.0 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
golang.org/x/tools v0.43.0 // indirect
|
||||
golang.org/x/mod v0.35.0 // indirect
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0 // indirect
|
||||
golang.org/x/tools v0.44.0 // indirect
|
||||
google.golang.org/grpc v1.75.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
|
||||
@ -108,8 +108,6 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v28.5.1+incompatible h1:Bm8DchhSD2J6PsFzxC35TZo4TLGR2PdW/E69rU45NhM=
|
||||
github.com/docker/docker v28.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM=
|
||||
github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94=
|
||||
@ -166,8 +164,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
|
||||
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
|
||||
@ -222,8 +218,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@ -257,8 +251,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@ -288,8 +280,6 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@ -322,8 +312,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
@ -415,16 +403,16 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/image v0.39.0 h1:skVYidAEVKgn8lZ602XO75asgXBgLj9G/FE3RbuPFww=
|
||||
golang.org/x/image v0.39.0/go.mod h1:sIbmppfU+xFLPIG0FoVUTvyBMmgng1/XAMhQ2ft0hpA=
|
||||
golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI=
|
||||
golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM=
|
||||
golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -436,16 +424,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.43.0 h1:12BdW9CeB3Z+J/I/wj34VMl8X+fEXBxVR90JeMX5E7s=
|
||||
golang.org/x/tools v0.43.0/go.mod h1:uHkMso649BX2cZK6+RpuIPXS3ho2hZo4FVwfoy1vIk0=
|
||||
golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c=
|
||||
golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU=
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@ -581,11 +582,35 @@ type CORSConfig struct {
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
|
||||
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
|
||||
CSP CSPConfig `mapstructure:"csp"`
|
||||
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
|
||||
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
|
||||
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
|
||||
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
|
||||
CSP CSPConfig `mapstructure:"csp"`
|
||||
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
|
||||
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
|
||||
TrustForwardedIPForAPIKeyACL bool `mapstructure:"trust_forwarded_ip_for_api_key_acl"`
|
||||
trustForwardedIPForAPIKeyACLLive *atomic.Bool `mapstructure:"-"`
|
||||
}
|
||||
|
||||
func (c *Config) TrustForwardedIPForAPIKeyACL() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
live := c.Security.trustForwardedIPForAPIKeyACLLive
|
||||
if live == nil {
|
||||
return c.Security.TrustForwardedIPForAPIKeyACL
|
||||
}
|
||||
return live.Load()
|
||||
}
|
||||
|
||||
func (c *Config) SetTrustForwardedIPForAPIKeyACL(enabled bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Security.TrustForwardedIPForAPIKeyACL = enabled
|
||||
if c.Security.trustForwardedIPForAPIKeyACLLive == nil {
|
||||
c.Security.trustForwardedIPForAPIKeyACLLive = &atomic.Bool{}
|
||||
}
|
||||
c.Security.trustForwardedIPForAPIKeyACLLive.Store(enabled)
|
||||
}
|
||||
|
||||
type URLAllowlistConfig struct {
|
||||
@ -1083,7 +1108,8 @@ type GatewaySchedulingConfig struct {
|
||||
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
|
||||
|
||||
// 负载计算
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
|
||||
LoadBatchCacheTTLMS int `mapstructure:"load_batch_cache_ttl_ms"`
|
||||
// 快照桶读取时的 MGET 分块大小
|
||||
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
|
||||
// 快照重建时的缓存写入分块大小
|
||||
@ -1466,6 +1492,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
|
||||
cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed)
|
||||
cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove)
|
||||
cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy)
|
||||
cfg.SetTrustForwardedIPForAPIKeyACL(cfg.Security.TrustForwardedIPForAPIKeyACL)
|
||||
cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level))
|
||||
cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format))
|
||||
cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName)
|
||||
@ -1610,6 +1637,7 @@ func setDefaults() {
|
||||
viper.SetDefault("security.csp.enabled", true)
|
||||
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
|
||||
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
|
||||
viper.SetDefault("security.trust_forwarded_ip_for_api_key_acl", false)
|
||||
|
||||
// Security - disable direct fallback on proxy error
|
||||
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
|
||||
@ -1905,6 +1933,7 @@ func setDefaults() {
|
||||
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
|
||||
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
|
||||
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
|
||||
viper.SetDefault("gateway.scheduling.load_batch_cache_ttl_ms", 200)
|
||||
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
|
||||
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
|
||||
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
|
||||
@ -2766,6 +2795,9 @@ func (c *Config) Validate() error {
|
||||
if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive")
|
||||
}
|
||||
if c.Gateway.Scheduling.LoadBatchCacheTTLMS < 0 {
|
||||
return fmt.Errorf("gateway.scheduling.load_batch_cache_ttl_ms must be non-negative")
|
||||
}
|
||||
if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 {
|
||||
return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive")
|
||||
}
|
||||
|
||||
@ -73,6 +73,9 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) {
|
||||
if !cfg.Gateway.Scheduling.LoadBatchEnabled {
|
||||
t.Fatalf("LoadBatchEnabled = false, want true")
|
||||
}
|
||||
if cfg.Gateway.Scheduling.LoadBatchCacheTTLMS != 200 {
|
||||
t.Fatalf("LoadBatchCacheTTLMS = %d, want 200", cfg.Gateway.Scheduling.LoadBatchCacheTTLMS)
|
||||
}
|
||||
if cfg.Gateway.Scheduling.SlotCleanupInterval != 30*time.Second {
|
||||
t.Fatalf("SlotCleanupInterval = %v, want 30s", cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
@ -1415,6 +1418,11 @@ func TestValidateConfigErrors(t *testing.T) {
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
|
||||
wantErr: "gateway.scheduling.sticky_session_max_waiting",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling load batch cache ttl",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.LoadBatchCacheTTLMS = -1 },
|
||||
wantErr: "gateway.scheduling.load_batch_cache_ttl_ms",
|
||||
},
|
||||
{
|
||||
name: "gateway scheduling outbox poll",
|
||||
mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
|
||||
|
||||
@ -20,34 +20,35 @@ func NewContentModerationHandler(svc *service.ContentModerationService) *Content
|
||||
}
|
||||
|
||||
type contentModerationConfigRequest struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
APIKeysMode string `json:"api_keys_mode"`
|
||||
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords *[]string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
APIKeysMode string `json:"api_keys_mode"`
|
||||
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords *[]string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
|
||||
ModelFilter *service.ContentModerationModelFilter `json:"model_filter"`
|
||||
}
|
||||
|
||||
type contentModerationAPIKeyTestRequest struct {
|
||||
@ -107,6 +108,7 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||
PreHashCheckEnabled: req.PreHashCheckEnabled,
|
||||
BlockedKeywords: req.BlockedKeywords,
|
||||
KeywordBlockingMode: req.KeywordBlockingMode,
|
||||
ModelFilter: req.ModelFilter,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@ -307,6 +307,51 @@ func (h *RedeemHandler) BatchDelete(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// BatchUpdate handles batch updating redeem codes
|
||||
// POST /api/v1/admin/redeem-codes/batch-update
|
||||
func (h *RedeemHandler) BatchUpdate(c *gin.Context) {
|
||||
if h.redeemService == nil {
|
||||
response.InternalError(c, "redeem service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.BatchUpdateRedeemCodesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.redeemService.BatchUpdate(c.Request.Context(), &service.RedeemCodeBatchUpdateInput{
|
||||
IDs: req.IDs,
|
||||
Fields: redeemBatchUpdateFieldsFromDTO(req.Fields),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"updated": result.Updated,
|
||||
"message": "Redeem codes updated successfully",
|
||||
})
|
||||
}
|
||||
|
||||
func redeemBatchUpdateFieldsFromDTO(in dto.BatchUpdateRedeemCodeFields) service.RedeemCodeBatchUpdateFields {
|
||||
out := service.RedeemCodeBatchUpdateFields{
|
||||
Status: in.Status,
|
||||
Notes: in.Notes,
|
||||
Type: in.Type,
|
||||
Value: in.Value,
|
||||
}
|
||||
if in.ExpiresAt.Set {
|
||||
out.ExpiresAt = service.NullableTimeUpdate{Set: true, Value: in.ExpiresAt.Value}
|
||||
}
|
||||
if in.GroupID.Set {
|
||||
out.GroupID = service.NullableInt64Update{Set: true, Value: in.GroupID.Value}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Expire handles expiring a redeem code
|
||||
// POST /api/v1/admin/redeem-codes/:id/expire
|
||||
func (h *RedeemHandler) Expire(c *gin.Context) {
|
||||
|
||||
@ -142,6 +142,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
APIKeyACLTrustForwardedIP: settings.APIKeyACLTrustForwardedIP,
|
||||
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
@ -264,6 +265,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
BalanceLowNotifyEnabled: settings.BalanceLowNotifyEnabled,
|
||||
BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold,
|
||||
BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL,
|
||||
SubscriptionExpiryNotifyEnabled: settings.SubscriptionExpiryNotifyEnabled,
|
||||
AccountQuotaNotifyEnabled: settings.AccountQuotaNotifyEnabled,
|
||||
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(settings.AccountQuotaNotifyEmails),
|
||||
PaymentEnabled: paymentCfg.Enabled,
|
||||
@ -399,6 +401,9 @@ type UpdateSettingsRequest struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// API Key IP 访问控制设置
|
||||
APIKeyACLTrustForwardedIP *bool `json:"api_key_acl_trust_forwarded_ip"`
|
||||
|
||||
// LinuxDo Connect OAuth 登录
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
@ -582,12 +587,13 @@ type UpdateSettingsRequest struct {
|
||||
// OpenAI account scheduling
|
||||
OpenAIAdvancedSchedulerEnabled *bool `json:"openai_advanced_scheduler_enabled"`
|
||||
|
||||
// Balance low notification
|
||||
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
|
||||
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
|
||||
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
|
||||
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||
// 余额不足提醒
|
||||
BalanceLowNotifyEnabled *bool `json:"balance_low_notify_enabled"`
|
||||
BalanceLowNotifyThreshold *float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL *string `json:"balance_low_notify_recharge_url"`
|
||||
SubscriptionExpiryNotifyEnabled *bool `json:"subscription_expiry_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled *bool `json:"account_quota_notify_enabled"`
|
||||
AccountQuotaNotifyEmails *[]dto.NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||
|
||||
// Payment configuration (integrated into settings, full replace)
|
||||
PaymentEnabled *bool `json:"payment_enabled"`
|
||||
@ -1432,28 +1438,34 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
APIKeyACLTrustForwardedIP: func() bool {
|
||||
if req.APIKeyACLTrustForwardedIP != nil {
|
||||
return *req.APIKeyACLTrustForwardedIP
|
||||
}
|
||||
return previousSettings.APIKeyACLTrustForwardedIP
|
||||
}(),
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
@ -1669,6 +1681,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
return previousSettings.BalanceLowNotifyRechargeURL
|
||||
}(),
|
||||
SubscriptionExpiryNotifyEnabled: func() bool {
|
||||
if req.SubscriptionExpiryNotifyEnabled != nil {
|
||||
return *req.SubscriptionExpiryNotifyEnabled
|
||||
}
|
||||
return previousSettings.SubscriptionExpiryNotifyEnabled
|
||||
}(),
|
||||
AccountQuotaNotifyEnabled: func() bool {
|
||||
if req.AccountQuotaNotifyEnabled != nil {
|
||||
return *req.AccountQuotaNotifyEnabled
|
||||
@ -1869,6 +1887,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
APIKeyACLTrustForwardedIP: updatedSettings.APIKeyACLTrustForwardedIP,
|
||||
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
@ -1989,6 +2008,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
BalanceLowNotifyEnabled: updatedSettings.BalanceLowNotifyEnabled,
|
||||
BalanceLowNotifyThreshold: updatedSettings.BalanceLowNotifyThreshold,
|
||||
BalanceLowNotifyRechargeURL: updatedSettings.BalanceLowNotifyRechargeURL,
|
||||
SubscriptionExpiryNotifyEnabled: updatedSettings.SubscriptionExpiryNotifyEnabled,
|
||||
AccountQuotaNotifyEnabled: updatedSettings.AccountQuotaNotifyEnabled,
|
||||
AccountQuotaNotifyEmails: dto.NotifyEmailEntriesFromService(updatedSettings.AccountQuotaNotifyEmails),
|
||||
PaymentEnabled: updatedPaymentCfg.Enabled,
|
||||
@ -2145,6 +2165,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.APIKeyACLTrustForwardedIP != after.APIKeyACLTrustForwardedIP {
|
||||
changed = append(changed, "api_key_acl_trust_forwarded_ip")
|
||||
}
|
||||
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||
changed = append(changed, "linuxdo_connect_enabled")
|
||||
}
|
||||
@ -2454,7 +2477,7 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.OpenAIAdvancedSchedulerEnabled != after.OpenAIAdvancedSchedulerEnabled {
|
||||
changed = append(changed, "openai_advanced_scheduler_enabled")
|
||||
}
|
||||
// Balance & quota notification
|
||||
// 余额、订阅到期与账号限额通知
|
||||
if before.BalanceLowNotifyEnabled != after.BalanceLowNotifyEnabled {
|
||||
changed = append(changed, "balance_low_notify_enabled")
|
||||
}
|
||||
@ -2464,6 +2487,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.BalanceLowNotifyRechargeURL != after.BalanceLowNotifyRechargeURL {
|
||||
changed = append(changed, "balance_low_notify_recharge_url")
|
||||
}
|
||||
if before.SubscriptionExpiryNotifyEnabled != after.SubscriptionExpiryNotifyEnabled {
|
||||
changed = append(changed, "subscription_expiry_notify_enabled")
|
||||
}
|
||||
if before.AccountQuotaNotifyEnabled != after.AccountQuotaNotifyEnabled {
|
||||
changed = append(changed, "account_quota_notify_enabled")
|
||||
}
|
||||
@ -3472,6 +3498,8 @@ func emailTemplateEventOptionsToDTO(events []service.NotificationEmailEventInfo)
|
||||
Value: event.Event,
|
||||
Label: event.Label,
|
||||
Description: event.Description,
|
||||
Category: event.Category,
|
||||
Optional: event.Optional,
|
||||
})
|
||||
}
|
||||
return items
|
||||
|
||||
@ -2492,6 +2492,10 @@ func (r *oauthPendingFlowRedeemCodeRepo) Update(ctx context.Context, code *servi
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) BatchUpdate(context.Context, []int64, service.RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
panic("unexpected BatchUpdate call")
|
||||
}
|
||||
|
||||
func (r *oauthPendingFlowRedeemCodeRepo) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@ -32,6 +33,13 @@ func findCookie(cookies []*http.Cookie, name string) *http.Cookie {
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
|
||||
t.Helper()
|
||||
cookie := findCookie(recorder.Result().Cookies(), name)
|
||||
require.NotNil(t, cookie)
|
||||
require.Equal(t, -1, cookie.MaxAge)
|
||||
}
|
||||
|
||||
func decodeCookieValueForTest(t *testing.T, value string) string {
|
||||
t.Helper()
|
||||
decoded, err := decodeCookieValue(value)
|
||||
@ -40,6 +48,13 @@ func decodeCookieValueForTest(t *testing.T, value string) string {
|
||||
}
|
||||
|
||||
func assertOAuthRedirectError(t *testing.T, location string, errorCode string, errorMessage string) {
|
||||
t.Helper()
|
||||
values := parseOAuthRedirectFragment(t, location)
|
||||
require.Equal(t, errorCode, values.Get("error"))
|
||||
require.Equal(t, errorMessage, values.Get("error_message"))
|
||||
}
|
||||
|
||||
func parseOAuthRedirectFragment(t *testing.T, location string) url.Values {
|
||||
t.Helper()
|
||||
require.NotEmpty(t, location)
|
||||
|
||||
@ -52,6 +67,5 @@ func assertOAuthRedirectError(t *testing.T, location string, errorCode string, e
|
||||
}
|
||||
values, err := url.ParseQuery(rawValues)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, errorCode, values.Get("error"))
|
||||
require.Equal(t, errorMessage, values.Get("error_message"))
|
||||
return values
|
||||
}
|
||||
|
||||
@ -454,6 +454,24 @@ func (h *AuthHandler) OIDCOAuthCallback(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 快捷路径:当上游返回已验证邮箱、部署不要求额外确认且本地没有同邮箱账号时,
|
||||
// 直接信任上游身份完成注册/登录,避免展示 choice 页。
|
||||
if compatEmailUser == nil &&
|
||||
strings.TrimSpace(compatEmail) != "" &&
|
||||
emailVerified != nil && *emailVerified {
|
||||
if handled := h.tryOIDCVerifiedEmailFastPath(
|
||||
c,
|
||||
frontendCallback,
|
||||
redirectTo,
|
||||
identityRef,
|
||||
compatEmail,
|
||||
username,
|
||||
upstreamClaims,
|
||||
); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if h.isForceEmailOnThirdPartySignup(c.Request.Context()) {
|
||||
if err := h.createOIDCOAuthChoicePendingSession(
|
||||
c,
|
||||
@ -1190,3 +1208,70 @@ func oidcClearCookie(c *gin.Context, name string, secure bool) {
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// tryOIDCVerifiedEmailFastPath 在 OIDC 上游已返回已验证邮箱时尝试跳过 choice/pending 页。
|
||||
// 返回 true 表示已经写出重定向响应;返回 false 表示调用方应继续回退到常规 choice 流程。
|
||||
func (h *AuthHandler) tryOIDCVerifiedEmailFastPath(
|
||||
c *gin.Context,
|
||||
frontendCallback string,
|
||||
redirectTo string,
|
||||
identity service.PendingAuthIdentityKey,
|
||||
compatEmail string,
|
||||
username string,
|
||||
upstreamClaims map[string]any,
|
||||
) bool {
|
||||
if h == nil || h.authService == nil || h.settingSvc == nil {
|
||||
return false
|
||||
}
|
||||
ctx := c.Request.Context()
|
||||
if h.isForceEmailOnThirdPartySignup(ctx) {
|
||||
return false
|
||||
}
|
||||
if h.settingSvc.IsInvitationCodeEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
if err := h.ensureBackendModeAllowsNewUserLogin(ctx); err != nil {
|
||||
log.Printf("[OIDC OAuth] verified-email fast path blocked by backend mode: reason=%s", infraerrors.Reason(err))
|
||||
clearOAuthPendingSessionCookie(c, isRequestHTTPS(c))
|
||||
clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c))
|
||||
redirectOAuthError(c, frontendCallback, "login_blocked", infraerrors.Reason(err), infraerrors.Message(err))
|
||||
return true
|
||||
}
|
||||
|
||||
verifiedEmail := strings.TrimSpace(strings.ToLower(compatEmail))
|
||||
upstreamMetadata := make(map[string]any, len(upstreamClaims)+1)
|
||||
for k, v := range upstreamClaims {
|
||||
upstreamMetadata[k] = v
|
||||
}
|
||||
if syntheticEmail := pendingSessionStringValue(upstreamClaims, "email"); syntheticEmail != "" && !strings.EqualFold(syntheticEmail, verifiedEmail) {
|
||||
upstreamMetadata["synthetic_email"] = syntheticEmail
|
||||
}
|
||||
upstreamMetadata["email"] = verifiedEmail
|
||||
input := service.EmailOAuthIdentityInput{
|
||||
ProviderType: strings.TrimSpace(identity.ProviderType),
|
||||
ProviderKey: strings.TrimSpace(identity.ProviderKey),
|
||||
ProviderSubject: strings.TrimSpace(identity.ProviderSubject),
|
||||
Email: verifiedEmail,
|
||||
EmailVerified: true,
|
||||
Username: strings.TrimSpace(username),
|
||||
DisplayName: pendingSessionStringValue(upstreamClaims, "suggested_display_name"),
|
||||
AvatarURL: pendingSessionStringValue(upstreamClaims, "suggested_avatar_url"),
|
||||
UpstreamMetadata: upstreamMetadata,
|
||||
}
|
||||
tokenPair, _, err := h.authService.LoginOrRegisterVerifiedEmailOAuthWithInvitation(ctx, input, "", "")
|
||||
if err != nil {
|
||||
log.Printf("[OIDC OAuth] verified-email fast path skipped: reason=%s", infraerrors.Reason(err))
|
||||
return false
|
||||
}
|
||||
|
||||
fragment := url.Values{}
|
||||
fragment.Set("access_token", tokenPair.AccessToken)
|
||||
fragment.Set("refresh_token", tokenPair.RefreshToken)
|
||||
fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn))
|
||||
fragment.Set("token_type", "Bearer")
|
||||
fragment.Set("redirect", redirectTo)
|
||||
clearOAuthPendingSessionCookie(c, isRequestHTTPS(c))
|
||||
clearOAuthPendingBrowserCookie(c, isRequestHTTPS(c))
|
||||
redirectWithFragment(c, frontendCallback, fragment)
|
||||
return true
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -926,6 +927,232 @@ func TestCompleteOIDCOAuthRegistrationRejectsIdentityOwnershipConflictBeforeUser
|
||||
require.Nil(t, storedSession.ConsumedAt)
|
||||
}
|
||||
|
||||
func TestTryOIDCVerifiedEmailFastPathCreatesUserAndIdentity(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, false)
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
ctx := context.Background()
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
|
||||
|
||||
identity := service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example.com",
|
||||
ProviderSubject: "fast-path-subject",
|
||||
}
|
||||
completed := handler.tryOIDCVerifiedEmailFastPath(
|
||||
c,
|
||||
"/auth/oidc/callback",
|
||||
"/dashboard",
|
||||
identity,
|
||||
"fastpath@example.com",
|
||||
"fastpath_user",
|
||||
map[string]any{
|
||||
"suggested_display_name": "Fast Path",
|
||||
"suggested_avatar_url": "",
|
||||
},
|
||||
)
|
||||
require.True(t, completed)
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "/auth/oidc/callback")
|
||||
require.Contains(t, location, "access_token=")
|
||||
require.Contains(t, location, "refresh_token=")
|
||||
require.Contains(t, location, "token_type=Bearer")
|
||||
|
||||
user, err := client.User.Query().Where(dbuser.EmailEQ("fastpath@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fastpath_user", user.Username)
|
||||
require.Equal(t, "oidc", user.SignupSource)
|
||||
|
||||
identityRecord, err := client.AuthIdentity.Query().Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ("https://issuer.example.com"),
|
||||
authidentity.ProviderSubjectEQ("fast-path-subject"),
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fastpath@example.com", identityRecord.Metadata["email"])
|
||||
require.Equal(t, true, identityRecord.Metadata["email_verified"])
|
||||
|
||||
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, pendingCount)
|
||||
}
|
||||
|
||||
func TestOIDCOAuthCallbackVerifiedEmailFastPathIssuesTokenWithoutPendingSession(t *testing.T) {
|
||||
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
||||
Subject: "oidc-fast-callback-subject",
|
||||
PreferredUsername: "oidc_fast_callback",
|
||||
DisplayName: "OIDC Fast Callback",
|
||||
AvatarURL: "https://cdn.example/oidc-fast.png",
|
||||
Email: "oidc-fast-callback@example.com",
|
||||
EmailVerified: true,
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, nil)
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-fast-callback", nil)
|
||||
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-fast-callback"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-fast-callback"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-callback-subject"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-fast-callback"))
|
||||
c.Request = req
|
||||
|
||||
handler.OIDCOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
location := recorder.Header().Get("Location")
|
||||
require.Contains(t, location, "/auth/oidc/callback#")
|
||||
require.Contains(t, location, "access_token=")
|
||||
require.Contains(t, location, "refresh_token=")
|
||||
require.Contains(t, location, "token_type=Bearer")
|
||||
fragmentValues := parseOAuthRedirectFragment(t, location)
|
||||
require.Equal(t, "/dashboard", fragmentValues.Get("redirect"))
|
||||
requireCookieCleared(t, recorder, oauthPendingSessionCookieName)
|
||||
requireCookieCleared(t, recorder, oauthPendingBrowserCookieName)
|
||||
|
||||
ctx := context.Background()
|
||||
user, err := client.User.Query().Where(dbuser.EmailEQ("oidc-fast-callback@example.com")).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "oidc_fast_callback", user.Username)
|
||||
require.Equal(t, "oidc", user.SignupSource)
|
||||
|
||||
identity, err := client.AuthIdentity.Query().Where(
|
||||
authidentity.ProviderTypeEQ("oidc"),
|
||||
authidentity.ProviderKeyEQ(cfg.IssuerURL),
|
||||
authidentity.ProviderSubjectEQ("oidc-fast-callback-subject"),
|
||||
authidentity.UserIDEQ(user.ID),
|
||||
).Only(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "oidc-fast-callback@example.com", identity.Metadata["email"])
|
||||
require.Equal(t, true, identity.Metadata["email_verified"])
|
||||
require.Equal(t, "OIDC Fast Callback", identity.Metadata["suggested_display_name"])
|
||||
require.NotEqual(t, identity.Metadata["email"], identity.Metadata["synthetic_email"])
|
||||
|
||||
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, pendingCount)
|
||||
}
|
||||
|
||||
func TestOIDCOAuthCallbackVerifiedEmailFastPathBackendModeBlocksBeforeUserCreation(t *testing.T) {
|
||||
cfg, cleanup := newOIDCTestProvider(t, oidcProviderFixture{
|
||||
Subject: "oidc-fast-backend-mode-subject",
|
||||
PreferredUsername: "oidc_backend_mode",
|
||||
DisplayName: "OIDC Backend Mode",
|
||||
Email: "oidc-backend-mode@example.com",
|
||||
EmailVerified: true,
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
handler, client := newOIDCOAuthHandlerAndClientWithSettings(t, false, cfg, map[string]string{
|
||||
service.SettingKeyBackendModeEnabled: "true",
|
||||
})
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback?code=oidc-code&state=state-backend-mode", nil)
|
||||
req.AddCookie(encodedCookie(oidcOAuthStateCookieName, "state-backend-mode"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthRedirectCookie, "/dashboard"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthVerifierCookie, "verifier-backend-mode"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthNonceCookie, "nonce-oidc-fast-backend-mode-subject"))
|
||||
req.AddCookie(encodedCookie(oidcOAuthIntentCookieName, oauthIntentLogin))
|
||||
req.AddCookie(encodedCookie(oauthPendingBrowserCookieName, "browser-backend-mode"))
|
||||
c.Request = req
|
||||
|
||||
handler.OIDCOAuthCallback(c)
|
||||
|
||||
require.Equal(t, http.StatusFound, recorder.Code)
|
||||
assertOAuthRedirectError(t, recorder.Header().Get("Location"), "login_blocked", "BACKEND_MODE_ADMIN_ONLY")
|
||||
requireCookieCleared(t, recorder, oauthPendingSessionCookieName)
|
||||
requireCookieCleared(t, recorder, oauthPendingBrowserCookieName)
|
||||
|
||||
ctx := context.Background()
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("oidc-backend-mode@example.com")).Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
identityCount, err := client.AuthIdentity.Query().
|
||||
Where(authidentity.ProviderSubjectEQ("oidc-fast-backend-mode-subject")).
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, identityCount)
|
||||
pendingCount, err := client.PendingAuthSession.Query().Count(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, pendingCount)
|
||||
}
|
||||
|
||||
func TestTryOIDCVerifiedEmailFastPathSkippedWhenInvitationCodeRequired(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandler(t, true)
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
|
||||
|
||||
identity := service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example.com",
|
||||
ProviderSubject: "fast-path-skipped-invitation",
|
||||
}
|
||||
completed := handler.tryOIDCVerifiedEmailFastPath(
|
||||
c,
|
||||
"/auth/oidc/callback",
|
||||
"/dashboard",
|
||||
identity,
|
||||
"invite-only@example.com",
|
||||
"invite_only_user",
|
||||
map[string]any{},
|
||||
)
|
||||
require.False(t, completed)
|
||||
require.NotEqual(t, http.StatusFound, recorder.Code)
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("invite-only@example.com")).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
}
|
||||
|
||||
func TestTryOIDCVerifiedEmailFastPathSkippedWhenForceEmailEnabled(t *testing.T) {
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
settingValues: map[string]string{
|
||||
service.SettingKeyForceEmailOnThirdPartySignup: "true",
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/oauth/oidc/callback", nil)
|
||||
|
||||
identity := service.PendingAuthIdentityKey{
|
||||
ProviderType: "oidc",
|
||||
ProviderKey: "https://issuer.example.com",
|
||||
ProviderSubject: "fast-path-skipped-force-email",
|
||||
}
|
||||
completed := handler.tryOIDCVerifiedEmailFastPath(
|
||||
c,
|
||||
"/auth/oidc/callback",
|
||||
"/dashboard",
|
||||
identity,
|
||||
"force-email@example.com",
|
||||
"force_email_user",
|
||||
map[string]any{},
|
||||
)
|
||||
require.False(t, completed)
|
||||
|
||||
userCount, err := client.User.Query().Where(dbuser.EmailEQ("force-email@example.com")).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, userCount)
|
||||
}
|
||||
|
||||
type oidcProviderFixture struct {
|
||||
Subject string
|
||||
PreferredUsername string
|
||||
@ -957,6 +1184,49 @@ func newOIDCOAuthHandlerAndClient(t *testing.T, invitationEnabled bool, oauthCfg
|
||||
return handler, client
|
||||
}
|
||||
|
||||
func newOIDCOAuthHandlerAndClientWithSettings(
|
||||
t *testing.T,
|
||||
invitationEnabled bool,
|
||||
oauthCfg config.OIDCConnectConfig,
|
||||
settingValues map[string]string,
|
||||
) (*AuthHandler, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
values := map[string]string{
|
||||
service.SettingKeyOIDCConnectEnabled: "true",
|
||||
service.SettingKeyOIDCConnectProviderName: strings.TrimSpace(firstNonEmpty(oauthCfg.ProviderName, "OIDC")),
|
||||
service.SettingKeyOIDCConnectClientID: oauthCfg.ClientID,
|
||||
service.SettingKeyOIDCConnectClientSecret: oauthCfg.ClientSecret,
|
||||
service.SettingKeyOIDCConnectIssuerURL: oauthCfg.IssuerURL,
|
||||
service.SettingKeyOIDCConnectAuthorizeURL: oauthCfg.AuthorizeURL,
|
||||
service.SettingKeyOIDCConnectTokenURL: oauthCfg.TokenURL,
|
||||
service.SettingKeyOIDCConnectUserInfoURL: oauthCfg.UserInfoURL,
|
||||
service.SettingKeyOIDCConnectJWKSURL: oauthCfg.JWKSURL,
|
||||
service.SettingKeyOIDCConnectScopes: oauthCfg.Scopes,
|
||||
service.SettingKeyOIDCConnectRedirectURL: oauthCfg.RedirectURL,
|
||||
service.SettingKeyOIDCConnectFrontendRedirectURL: oauthCfg.FrontendRedirectURL,
|
||||
service.SettingKeyOIDCConnectTokenAuthMethod: oauthCfg.TokenAuthMethod,
|
||||
service.SettingKeyOIDCConnectUsePKCE: boolSettingValue(oauthCfg.UsePKCE),
|
||||
service.SettingKeyOIDCConnectValidateIDToken: boolSettingValue(oauthCfg.ValidateIDToken),
|
||||
service.SettingKeyOIDCConnectAllowedSigningAlgs: oauthCfg.AllowedSigningAlgs,
|
||||
service.SettingKeyOIDCConnectClockSkewSeconds: "120",
|
||||
service.SettingKeyOIDCConnectRequireEmailVerified: boolSettingValue(oauthCfg.RequireEmailVerified),
|
||||
}
|
||||
for key, value := range settingValues {
|
||||
values[key] = value
|
||||
}
|
||||
|
||||
handler, client := newOAuthPendingFlowTestHandlerWithDependencies(t, oauthPendingFlowTestHandlerOptions{
|
||||
invitationEnabled: invitationEnabled,
|
||||
settingValues: values,
|
||||
})
|
||||
if handler.cfg == nil {
|
||||
handler.cfg = &config.Config{}
|
||||
}
|
||||
handler.cfg.OIDC = oauthCfg
|
||||
return handler, client
|
||||
}
|
||||
|
||||
func newOIDCTestProvider(t *testing.T, fixture oidcProviderFixture) (config.OIDCConnectConfig, func()) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@ -50,6 +50,7 @@ type SystemSettings struct {
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"`
|
||||
APIKeyACLTrustForwardedIP bool `json:"api_key_acl_trust_forwarded_ip"`
|
||||
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
@ -222,12 +223,13 @@ type SystemSettings struct {
|
||||
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
|
||||
PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"`
|
||||
|
||||
// Balance low notification
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||
// 余额、订阅到期与账号限额通知
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
BalanceLowNotifyRechargeURL string `json:"balance_low_notify_recharge_url"`
|
||||
SubscriptionExpiryNotifyEnabled bool `json:"subscription_expiry_notify_enabled"`
|
||||
AccountQuotaNotifyEnabled bool `json:"account_quota_notify_enabled"`
|
||||
AccountQuotaNotifyEmails []NotifyEmailEntry `json:"account_quota_notify_emails"`
|
||||
|
||||
// Channel Monitor feature switch
|
||||
ChannelMonitorEnabled bool `json:"channel_monitor_enabled"`
|
||||
@ -378,11 +380,13 @@ type OpenAIFastPolicySettings struct {
|
||||
Rules []OpenAIFastPolicyRule `json:"rules"`
|
||||
}
|
||||
|
||||
// EmailTemplateEventOption describes an editable notification email event.
|
||||
// EmailTemplateEventOption 描述可编辑的通知邮件事件。
|
||||
type EmailTemplateEventOption struct {
|
||||
Value string `json:"value"`
|
||||
Label string `json:"label,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Category string `json:"category,omitempty"`
|
||||
Optional bool `json:"optional,omitempty"`
|
||||
}
|
||||
|
||||
// EmailTemplateSummary is shown in the admin email template list.
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
@ -359,6 +361,59 @@ type AdminRedeemCode struct {
|
||||
Notes string `json:"notes"`
|
||||
}
|
||||
|
||||
type NullableTimeField struct {
|
||||
Set bool
|
||||
Value *time.Time
|
||||
}
|
||||
|
||||
func (f *NullableTimeField) UnmarshalJSON(data []byte) error {
|
||||
f.Set = true
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
f.Value = nil
|
||||
return nil
|
||||
}
|
||||
var value time.Time
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Value = &value
|
||||
return nil
|
||||
}
|
||||
|
||||
type NullableInt64Field struct {
|
||||
Set bool
|
||||
Value *int64
|
||||
}
|
||||
|
||||
func (f *NullableInt64Field) UnmarshalJSON(data []byte) error {
|
||||
f.Set = true
|
||||
if bytes.Equal(data, []byte("null")) {
|
||||
f.Value = nil
|
||||
return nil
|
||||
}
|
||||
var value int64
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Value = &value
|
||||
return nil
|
||||
}
|
||||
|
||||
type BatchUpdateRedeemCodeFields struct {
|
||||
Status *string `json:"status,omitempty"`
|
||||
ExpiresAt NullableTimeField `json:"expires_at,omitempty"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
GroupID NullableInt64Field `json:"group_id,omitempty"`
|
||||
|
||||
Type *string `json:"type,omitempty"`
|
||||
Value *float64 `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
type BatchUpdateRedeemCodesRequest struct {
|
||||
IDs []int64 `json:"ids" binding:"required,min=1"`
|
||||
Fields BatchUpdateRedeemCodeFields `json:"fields" binding:"required"`
|
||||
}
|
||||
|
||||
// UsageLog 是普通用户接口使用的 usage log DTO(不包含管理员字段)。
|
||||
type UsageLog struct {
|
||||
ID int64 `json:"id"`
|
||||
|
||||
@ -785,8 +785,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
parsedReq.Model = channelMapping.MappedModel
|
||||
parsedReq.Body = h.gatewayService.ReplaceModelInBody(parsedReq.Body, channelMapping.MappedModel)
|
||||
body = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
// Bedrock CC 兼容:渠道模型映射后,清理 Anthropic API 专有字段、注入 Bedrock 必需字段
|
||||
parsedReq.Body = h.gatewayService.ApplyBedrockCCCompat(c.Request.Context(), parsedReq.Body, parsedReq.Model, account, apiKey.GroupID)
|
||||
body = parsedReq.Body
|
||||
|
||||
// 转发请求 - 根据账号平台分流
|
||||
c.Set("parsed_request", parsedReq)
|
||||
|
||||
@ -179,12 +179,16 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
}()
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -236,6 +240,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -333,11 +333,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
}()
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -389,6 +393,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
@ -722,12 +730,16 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
if channelMappingMsg.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMappingMsg.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardAsAnthropic(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||
}()
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -775,6 +787,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleAnthropicFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai_messages.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -195,11 +195,15 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
result, err := h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||
result, err := func() (*service.OpenAIForwardResult, error) {
|
||||
defer func() {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
}()
|
||||
return h.gatewayService.ForwardImages(c.Request.Context(), c, account, body, parsed, channelMapping.MappedModel)
|
||||
}()
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
@ -217,6 +221,18 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
var imageUpstreamErr *service.OpenAIImagesUpstreamError
|
||||
if errors.As(err, &imageUpstreamErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
reqLog.Warn("openai.images.upstream_user_error",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("status_code", imageUpstreamErr.StatusCode),
|
||||
zap.String("error_type", imageUpstreamErr.ErrorType),
|
||||
zap.String("error_code", imageUpstreamErr.Code),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
@ -246,6 +262,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
if h.gatewayService.ShouldStopOpenAIOAuth429Failover(account, failoverErr.StatusCode, switchCount) {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
reqLog.Warn("openai.images.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
|
||||
@ -41,13 +41,18 @@ const (
|
||||
opsErrInsufficientQuota = "insufficient_quota"
|
||||
|
||||
// 上游错误码常量 — 错误分类 (normalizeOpsErrorType / classifyOpsPhase / classifyOpsIsBusinessLimited)
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
opsCodeInvalidAPIKey = "INVALID_API_KEY"
|
||||
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
|
||||
opsCodeInsufficientBalance = "INSUFFICIENT_BALANCE"
|
||||
opsCodeUsageLimitExceeded = "USAGE_LIMIT_EXCEEDED"
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
opsCodeInvalidAPIKey = "INVALID_API_KEY"
|
||||
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
|
||||
opsCodeAPIKeyExpired = "API_KEY_EXPIRED"
|
||||
opsCodeAPIKeyDisabled = "API_KEY_DISABLED"
|
||||
opsCodeUserNotFound = "USER_NOT_FOUND"
|
||||
opsCodeAPIKeyQuotaExhausted = "API_KEY_QUOTA_EXHAUSTED"
|
||||
opsCodeAPIKeyQueryDeprecated = "api_key_in_query_deprecated"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -1091,8 +1096,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
if isOpsClientAuthError(code, msg) {
|
||||
return "auth"
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
if isOpsLocalBusinessLimitError(code, msg) {
|
||||
return "request"
|
||||
}
|
||||
|
||||
@ -1151,8 +1155,10 @@ func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status i
|
||||
if routingCapacityLimited {
|
||||
phase = "routing"
|
||||
}
|
||||
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
|
||||
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
|
||||
msg := strings.ToLower(message)
|
||||
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, msg)
|
||||
localBusinessLimited := !upstreamError && classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
|
||||
isBusinessLimited = routingCapacityLimited || (clientBusinessLimited && !upstreamError) || localBusinessLimited
|
||||
errorOwner = classifyOpsErrorOwner(phase, message)
|
||||
errorSource = classifyOpsErrorSource(phase, message)
|
||||
return phase, isBusinessLimited, errorOwner, errorSource
|
||||
@ -1162,8 +1168,7 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
||||
if len(localClientAuthError) > 0 && localClientAuthError[0] {
|
||||
return true
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
if isOpsLocalBusinessLimitError(code, strings.ToLower(message)) {
|
||||
return true
|
||||
}
|
||||
if phase == "billing" || phase == "concurrency" {
|
||||
@ -1180,10 +1185,45 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
||||
|
||||
func isOpsClientAuthError(code string, msg string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
|
||||
case opsCodeInvalidAPIKey,
|
||||
opsCodeAPIKeyRequired,
|
||||
opsCodeAPIKeyExpired,
|
||||
opsCodeAPIKeyDisabled,
|
||||
opsCodeUserNotFound,
|
||||
opsCodeUserInactive:
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
|
||||
return strings.Contains(msg, "invalid api key") ||
|
||||
strings.Contains(msg, "api key is required") ||
|
||||
strings.Contains(msg, "api key is disabled") ||
|
||||
strings.Contains(msg, "user associated with api key not found") ||
|
||||
strings.Contains(msg, "user account is not active")
|
||||
}
|
||||
|
||||
func isOpsLocalBusinessLimitError(code string, msg string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance,
|
||||
opsCodeUsageLimitExceeded,
|
||||
opsCodeSubscriptionNotFound,
|
||||
opsCodeSubscriptionInvalid,
|
||||
opsCodeAPIKeyQuotaExhausted,
|
||||
opsCodeAPIKeyQueryDeprecated:
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "api key in query parameter is deprecated") ||
|
||||
strings.Contains(msg, "query parameter api_key is deprecated") ||
|
||||
strings.Contains(msg, "no active subscription found for this group") ||
|
||||
strings.Contains(msg, opsErrInsufficientBalance) ||
|
||||
strings.Contains(msg, "insufficient account balance") ||
|
||||
strings.Contains(msg, "api key group platform is not gemini") ||
|
||||
strings.Contains(msg, "api key 额度已用完") ||
|
||||
strings.Contains(msg, "api key 5小时限额已用完") ||
|
||||
strings.Contains(msg, "api key 日限额已用完") ||
|
||||
strings.Contains(msg, "api key 7天限额已用完") ||
|
||||
strings.Contains(msg, "daily usage limit exceeded") ||
|
||||
strings.Contains(msg, "weekly usage limit exceeded") ||
|
||||
strings.Contains(msg, "monthly usage limit exceeded") ||
|
||||
strings.Contains(msg, "requests-per-minute limit exceeded")
|
||||
}
|
||||
|
||||
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
|
||||
|
||||
@ -260,6 +260,34 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
code: "API_KEY_REQUIRED",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "expired local API key",
|
||||
errType: "api_error",
|
||||
message: "API key 已过期",
|
||||
code: "API_KEY_EXPIRED",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "disabled local API key",
|
||||
errType: "api_error",
|
||||
message: "API key is disabled",
|
||||
code: "API_KEY_DISABLED",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "local API key user missing",
|
||||
errType: "api_error",
|
||||
message: "User associated with API key not found",
|
||||
code: "USER_NOT_FOUND",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "inactive local API key user",
|
||||
errType: "api_error",
|
||||
message: "User account is not active",
|
||||
code: "USER_INACTIVE",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google invalid API key",
|
||||
errType: "api_error",
|
||||
@ -274,6 +302,27 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google disabled API key",
|
||||
errType: "api_error",
|
||||
message: "API key is disabled",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google local API key user missing",
|
||||
errType: "api_error",
|
||||
message: "User associated with API key not found",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google inactive local API key user",
|
||||
errType: "api_error",
|
||||
message: "User account is not active",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -294,6 +343,126 @@ func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsLocalBusinessLimitErrorsExcludedFromSLA(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
code string
|
||||
status int
|
||||
wantErrType string
|
||||
wantPhase string
|
||||
}{
|
||||
{
|
||||
name: "standard API key quota exhausted",
|
||||
errType: "api_error",
|
||||
message: "API key 额度已用完",
|
||||
code: "API_KEY_QUOTA_EXHAUSTED",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "standard query API key deprecated",
|
||||
errType: "api_error",
|
||||
message: "API key in query parameter is deprecated. Please use Authorization header instead.",
|
||||
code: "api_key_in_query_deprecated",
|
||||
status: http.StatusBadRequest,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google query API key deprecated",
|
||||
errType: "api_error",
|
||||
message: "Query parameter api_key is deprecated. Use Authorization header or key instead.",
|
||||
code: "400",
|
||||
status: http.StatusBadRequest,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google no active subscription",
|
||||
errType: "api_error",
|
||||
message: "No active subscription found for this group",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google insufficient account balance",
|
||||
errType: "api_error",
|
||||
message: "Insufficient account balance",
|
||||
code: "403",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "gateway billing cache insufficient balance",
|
||||
errType: "billing_error",
|
||||
message: "insufficient balance",
|
||||
code: "",
|
||||
status: http.StatusForbidden,
|
||||
wantErrType: "billing_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "gemini group platform mismatch",
|
||||
errType: "api_error",
|
||||
message: "API key group platform is not gemini",
|
||||
code: "400",
|
||||
status: http.StatusBadRequest,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "gateway API key 5h rate limit",
|
||||
errType: "api_error",
|
||||
message: "api key 5小时限额已用完",
|
||||
code: "rate_limit_exceeded",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "gateway group RPM limit",
|
||||
errType: "api_error",
|
||||
message: "group requests-per-minute limit exceeded",
|
||||
code: "rate_limit_exceeded",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
{
|
||||
name: "google subscription daily limit",
|
||||
errType: "api_error",
|
||||
message: "daily usage limit exceeded",
|
||||
code: "429",
|
||||
status: http.StatusTooManyRequests,
|
||||
wantErrType: "api_error",
|
||||
wantPhase: "request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
errType := normalizeOpsErrorType(tt.errType, tt.code)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
|
||||
|
||||
require.Equal(t, tt.wantErrType, errType)
|
||||
require.Equal(t, tt.wantPhase, phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
@ -372,23 +541,71 @@ func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
code string
|
||||
status int
|
||||
}{
|
||||
{
|
||||
name: "invalid API key",
|
||||
message: "Invalid API key",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "disabled API key",
|
||||
message: "API key is disabled",
|
||||
code: "API_KEY_DISABLED",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "gemini group platform mismatch",
|
||||
message: "API key group platform is not gemini",
|
||||
code: "400",
|
||||
status: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "provider balance error",
|
||||
message: "Insufficient account balance",
|
||||
code: "INSUFFICIENT_BALANCE",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider subscription error",
|
||||
message: "No active subscription found for this group",
|
||||
code: "SUBSCRIPTION_NOT_FOUND",
|
||||
status: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "provider quota error",
|
||||
message: "api key 额度已用完",
|
||||
code: "API_KEY_QUOTA_EXHAUSTED",
|
||||
status: http.StatusTooManyRequests,
|
||||
},
|
||||
}
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"Invalid API key",
|
||||
"401",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, tt.status, tt.message, "")
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
tt.message,
|
||||
tt.code,
|
||||
tt.status,
|
||||
)
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
|
||||
@ -89,7 +89,7 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
|
||||
return nil, fmt.Errorf("parse responses input item: %w", err)
|
||||
}
|
||||
|
||||
role := rawString(item["role"])
|
||||
role := chatCompletionsBridgeRole(rawString(item["role"]))
|
||||
itemType := rawString(item["type"])
|
||||
switch itemType {
|
||||
case "function_call":
|
||||
@ -130,9 +130,6 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
|
||||
continue
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
content := item["content"]
|
||||
if len(bytesTrimSpace(content)) == 0 {
|
||||
if text := rawString(item["text"]); text != "" {
|
||||
@ -152,6 +149,17 @@ func responsesInputToChatMessages(instructions string, inputRaw json.RawMessage)
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func chatCompletionsBridgeRole(role string) string {
|
||||
trimmed := strings.TrimSpace(role)
|
||||
if trimmed == "" {
|
||||
return "user"
|
||||
}
|
||||
if strings.EqualFold(trimmed, "developer") {
|
||||
return "system"
|
||||
}
|
||||
return role
|
||||
}
|
||||
|
||||
func responsesContentToChatContent(raw json.RawMessage, role string) (json.RawMessage, error) {
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "null" {
|
||||
|
||||
@ -0,0 +1,82 @@
|
||||
package apicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResponsesInputToChatMessages_DeveloperRoleMapsToSystem(t *testing.T) {
|
||||
messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"developer","content":"follow project instructions"}]`))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
|
||||
assert.Equal(t, "system", messages[0].Role)
|
||||
assert.JSONEq(t, `"follow project instructions"`, string(messages[0].Content))
|
||||
}
|
||||
|
||||
func TestResponsesInputToChatMessages_KeepsChatCompletionRoles(t *testing.T) {
|
||||
input := json.RawMessage(`[
|
||||
{"role":"system","content":"system message"},
|
||||
{"role":"user","content":"user message"},
|
||||
{"role":"assistant","content":"assistant message"},
|
||||
{"role":"tool","content":"tool message"}
|
||||
]`)
|
||||
|
||||
messages, err := responsesInputToChatMessages("", input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 4)
|
||||
|
||||
assert.Equal(t, []string{"system", "user", "assistant", "tool"}, chatMessageRoles(messages))
|
||||
}
|
||||
|
||||
func TestResponsesInputToChatMessages_EmptyRoleFallsBackToUser(t *testing.T) {
|
||||
messages, err := responsesInputToChatMessages("", json.RawMessage(`[{"role":"","content":"hello"}]`))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
|
||||
assert.Equal(t, "user", messages[0].Role)
|
||||
}
|
||||
|
||||
func TestResponsesInputToChatMessages_DeveloperRoleTrimAndCaseInsensitive(t *testing.T) {
|
||||
input := json.RawMessage(`[
|
||||
{"role":" Developer ","content":"one"},
|
||||
{"role":"\tDEVELOPER\n","content":"two"}
|
||||
]`)
|
||||
|
||||
messages, err := responsesInputToChatMessages("", input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 2)
|
||||
|
||||
assert.Equal(t, []string{"system", "system"}, chatMessageRoles(messages))
|
||||
}
|
||||
|
||||
func TestResponsesToChatCompletionsRequest_InstructionsAndInputDeveloperRole(t *testing.T) {
|
||||
req := &ResponsesRequest{
|
||||
Model: "gpt-4o",
|
||||
Instructions: "Use concise answers.",
|
||||
Input: json.RawMessage(`[
|
||||
{"role":"developer","content":[{"type":"input_text","text":"Prefer JSON."}]},
|
||||
{"role":"user","content":"Hello"}
|
||||
]`),
|
||||
}
|
||||
|
||||
out, err := ResponsesToChatCompletionsRequest(req)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, out.Messages, 3)
|
||||
|
||||
assert.Equal(t, []string{"system", "system", "user"}, chatMessageRoles(out.Messages))
|
||||
assert.JSONEq(t, `"Use concise answers."`, string(out.Messages[0].Content))
|
||||
assert.JSONEq(t, `"Prefer JSON."`, string(out.Messages[1].Content))
|
||||
assert.JSONEq(t, `"Hello"`, string(out.Messages[2].Content))
|
||||
}
|
||||
|
||||
func chatMessageRoles(messages []ChatMessage) []string {
|
||||
roles := make([]string, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
roles = append(roles, message.Role)
|
||||
}
|
||||
return roles
|
||||
}
|
||||
@ -325,6 +325,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
if account == nil {
|
||||
return nil
|
||||
}
|
||||
schedulable := account.Schedulable
|
||||
if account.Status == service.StatusError {
|
||||
schedulable = false
|
||||
}
|
||||
|
||||
builder := r.client.Account.UpdateOneID(account.ID).
|
||||
SetName(account.Name).
|
||||
@ -337,7 +341,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
SetPriority(account.Priority).
|
||||
SetStatus(account.Status).
|
||||
SetErrorMessage(account.ErrorMessage).
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetSchedulable(schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.RateMultiplier != nil {
|
||||
@ -458,6 +462,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.deleteSchedulerAccountSnapshot(ctx, id)
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
@ -724,6 +729,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetStatus(service.StatusError).
|
||||
SetErrorMessage(errorMsg).
|
||||
SetSchedulable(false).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -757,6 +763,15 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) deleteSchedulerAccountSnapshot(ctx context.Context, accountID int64) {
|
||||
if r == nil || r.schedulerCache == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.DeleteAccount(ctx, accountID); err != nil {
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] delete account snapshot failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
|
||||
if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
|
||||
return
|
||||
|
||||
@ -23,6 +23,7 @@ type AccountRepoSuite struct {
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
deleteIDs []int64
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
@ -53,6 +54,10 @@ func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *servic
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
s.deleteIDs = append(s.deleteIDs, accountID)
|
||||
if s.accounts != nil {
|
||||
delete(s.accounts, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -185,6 +190,27 @@ func (s *AccountRepoSuite) TestDelete() {
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_RemovesSchedulerAccountSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete-cache"})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Name: account.Name,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
s.Require().Equal([]int64{account.ID}, cacheRecorder.deleteIDs)
|
||||
s.Require().NotContains(cacheRecorder.accounts, account.ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
|
||||
@ -729,7 +755,7 @@ func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive, Schedulable: true})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
@ -737,6 +763,22 @@ func (s *AccountRepoSuite) TestSetError() {
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
s.Require().False(got.Schedulable)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateErrorStatusUnschedulesAccount() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-update-err", Status: service.StatusActive, Schedulable: true})
|
||||
account.Status = service.StatusError
|
||||
account.ErrorMessage = "token revoked"
|
||||
account.Schedulable = true
|
||||
|
||||
s.Require().NoError(s.repo.Update(s.ctx, account))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusError, got.Status)
|
||||
s.Require().Equal("token revoked", got.ErrorMessage)
|
||||
s.Require().False(got.Schedulable)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
|
||||
@ -236,6 +236,91 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
uniqueIDs := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return r.batchUpdate(ctx, tx.Client(), uniqueIDs, fields)
|
||||
}
|
||||
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
updated, err := r.batchUpdate(txCtx, tx.Client(), uniqueIDs, fields)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) batchUpdate(ctx context.Context, client *dbent.Client, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
existing, err := client.RedeemCode.Query().
|
||||
Where(redeemcode.IDIn(ids...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(existing) != len(ids) {
|
||||
return 0, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
if fields.TouchesUsedSensitiveFields() {
|
||||
for _, code := range existing {
|
||||
if code.Status == service.StatusUsed {
|
||||
return 0, service.ErrRedeemCodeUsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
up := client.RedeemCode.Update().Where(redeemcode.IDIn(ids...))
|
||||
if fields.Status != nil {
|
||||
up.SetStatus(*fields.Status)
|
||||
}
|
||||
if fields.Notes != nil {
|
||||
up.SetNotes(*fields.Notes)
|
||||
}
|
||||
if fields.ExpiresAt.Set {
|
||||
if fields.ExpiresAt.Value != nil {
|
||||
up.SetExpiresAt(*fields.ExpiresAt.Value)
|
||||
} else {
|
||||
up.ClearExpiresAt()
|
||||
}
|
||||
}
|
||||
if fields.GroupID.Set {
|
||||
if fields.GroupID.Value != nil {
|
||||
up.SetGroupID(*fields.GroupID.Value)
|
||||
} else {
|
||||
up.ClearGroupID()
|
||||
}
|
||||
}
|
||||
|
||||
affected, err := up.Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected != len(ids) {
|
||||
return 0, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return int64(affected), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
client := clientFromContext(ctx, r.client)
|
||||
|
||||
@ -4,6 +4,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -21,8 +22,8 @@ type RedeemCodeRepoSuite struct {
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.ctx = dbent.NewTxContext(context.Background(), tx)
|
||||
s.client = tx.Client()
|
||||
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
|
||||
}
|
||||
@ -237,6 +238,123 @@ func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
s.Require().Equal(float64(50), got.Value)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestBatchUpdate_PartialFieldsAndClear() {
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "batch-update-group"))
|
||||
groupID := group.ID
|
||||
expiresAt := time.Now().UTC().Add(2 * time.Hour)
|
||||
status := service.StatusDisabled
|
||||
notes := "batch note"
|
||||
|
||||
codeA := &service.RedeemCode{
|
||||
Code: "BATCH-UP-A",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
Notes: "old",
|
||||
}
|
||||
codeB := &service.RedeemCode{
|
||||
Code: "BATCH-UP-B",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 20,
|
||||
Status: service.StatusUnused,
|
||||
Notes: "old",
|
||||
}
|
||||
untouched := &service.RedeemCode{
|
||||
Code: "BATCH-UP-C",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 30,
|
||||
Status: service.StatusUnused,
|
||||
Notes: "keep",
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, codeA))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, codeB))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, untouched))
|
||||
|
||||
updated, err := s.repo.BatchUpdate(s.ctx, []int64{codeA.ID, codeB.ID}, service.RedeemCodeBatchUpdateFields{
|
||||
Status: &status,
|
||||
ExpiresAt: service.NullableTimeUpdate{Set: true, Value: &expiresAt},
|
||||
Notes: ¬es,
|
||||
GroupID: service.NullableInt64Update{Set: true, Value: &groupID},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), updated)
|
||||
|
||||
gotA, err := s.repo.GetByID(s.ctx, codeA.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("BATCH-UP-A", gotA.Code)
|
||||
s.Require().Equal(service.RedeemTypeBalance, gotA.Type)
|
||||
s.Require().Equal(float64(10), gotA.Value)
|
||||
s.Require().Equal(service.StatusDisabled, gotA.Status)
|
||||
s.Require().Equal(notes, gotA.Notes)
|
||||
s.Require().NotNil(gotA.ExpiresAt)
|
||||
s.Require().WithinDuration(expiresAt, *gotA.ExpiresAt, time.Second)
|
||||
s.Require().NotNil(gotA.GroupID)
|
||||
s.Require().Equal(groupID, *gotA.GroupID)
|
||||
|
||||
gotB, err := s.repo.GetByID(s.ctx, codeB.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusDisabled, gotB.Status)
|
||||
s.Require().Equal(notes, gotB.Notes)
|
||||
|
||||
gotUntouched, err := s.repo.GetByID(s.ctx, untouched.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusUnused, gotUntouched.Status)
|
||||
s.Require().Equal("keep", gotUntouched.Notes)
|
||||
s.Require().Nil(gotUntouched.ExpiresAt)
|
||||
s.Require().Nil(gotUntouched.GroupID)
|
||||
|
||||
updated, err = s.repo.BatchUpdate(s.ctx, []int64{codeA.ID}, service.RedeemCodeBatchUpdateFields{
|
||||
ExpiresAt: service.NullableTimeUpdate{Set: true},
|
||||
GroupID: service.NullableInt64Update{Set: true},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(1), updated)
|
||||
|
||||
gotA, err = s.repo.GetByID(s.ctx, codeA.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(gotA.ExpiresAt)
|
||||
s.Require().Nil(gotA.GroupID)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestBatchUpdate_InvalidIDRollsBack() {
|
||||
code := &service.RedeemCode{
|
||||
Code: "BATCH-UP-ROLLBACK",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
Notes: "keep",
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
notes := "changed"
|
||||
|
||||
_, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID, 999999}, service.RedeemCodeBatchUpdateFields{Notes: ¬es})
|
||||
s.Require().Error(err)
|
||||
s.Require().True(errors.Is(err, service.ErrRedeemCodeNotFound))
|
||||
|
||||
got, getErr := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(getErr)
|
||||
s.Require().Equal("keep", got.Notes)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestBatchUpdate_UsedCodeRejectsSensitiveFields() {
|
||||
code := &service.RedeemCode{
|
||||
Code: "BATCH-UP-USED",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUsed,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
status := service.StatusDisabled
|
||||
|
||||
_, err := s.repo.BatchUpdate(s.ctx, []int64{code.ID}, service.RedeemCodeBatchUpdateFields{Status: &status})
|
||||
s.Require().Error(err)
|
||||
s.Require().True(errors.Is(err, service.ErrRedeemCodeUsed))
|
||||
|
||||
got, getErr := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(getErr)
|
||||
s.Require().Equal(service.StatusUsed, got.Status)
|
||||
}
|
||||
|
||||
// --- Use ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
|
||||
@ -757,6 +757,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subtitle",
|
||||
"api_base_url": "https://api.example.com",
|
||||
"api_key_acl_trust_forwarded_ip": false,
|
||||
"contact_info": "support",
|
||||
"doc_url": "https://docs.example.com",
|
||||
"auth_source_default_email_balance": 0,
|
||||
@ -862,6 +863,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_alipay_force_qrcode": false,
|
||||
"balance_low_notify_enabled": false,
|
||||
"account_quota_notify_enabled": false,
|
||||
"subscription_expiry_notify_enabled": true,
|
||||
"balance_low_notify_threshold": 0,
|
||||
"balance_low_notify_recharge_url": "",
|
||||
"account_quota_notify_emails": [],
|
||||
@ -1014,6 +1016,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"site_logo": "",
|
||||
"site_subtitle": "Subscription to API Conversion Platform",
|
||||
"api_base_url": "",
|
||||
"api_key_acl_trust_forwarded_ip": false,
|
||||
"contact_info": "",
|
||||
"doc_url": "",
|
||||
"home_content": "",
|
||||
@ -1086,6 +1089,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"payment_alipay_force_qrcode": false,
|
||||
"balance_low_notify_enabled": false,
|
||||
"account_quota_notify_enabled": false,
|
||||
"subscription_expiry_notify_enabled": true,
|
||||
"balance_low_notify_threshold": 0,
|
||||
"balance_low_notify_recharge_url": "",
|
||||
"account_quota_notify_emails": [],
|
||||
@ -1254,7 +1258,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
@ -1852,6 +1856,10 @@ func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode)
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) BatchUpdate(ctx context.Context, ids []int64, fields service.RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@ import (
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@ -102,6 +101,16 @@ func ProvideRouter(
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
httpHandler := http.Handler(router)
|
||||
server := &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: httpHandler,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||
}
|
||||
|
||||
globalMaxSize := cfg.Server.MaxRequestBodySize
|
||||
if globalMaxSize <= 0 {
|
||||
@ -115,32 +124,31 @@ func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server {
|
||||
// 根据配置决定是否启用 H2C
|
||||
if cfg.Server.H2C.Enabled {
|
||||
h2cConfig := cfg.Server.H2C
|
||||
httpHandler = h2c.NewHandler(router, &http2.Server{
|
||||
if err := http2.ConfigureServer(server, &http2.Server{
|
||||
MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams,
|
||||
IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second,
|
||||
MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize),
|
||||
MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection),
|
||||
MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream),
|
||||
})
|
||||
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
|
||||
h2cConfig.MaxConcurrentStreams,
|
||||
h2cConfig.IdleTimeout,
|
||||
h2cConfig.MaxReadFrameSize,
|
||||
h2cConfig.MaxUploadBufferPerConnection,
|
||||
h2cConfig.MaxUploadBufferPerStream,
|
||||
)
|
||||
}); err != nil {
|
||||
log.Printf("Failed to configure HTTP/2 Cleartext (h2c): %v", err)
|
||||
} else {
|
||||
protocols := new(http.Protocols)
|
||||
protocols.SetHTTP1(true)
|
||||
protocols.SetUnencryptedHTTP2(true)
|
||||
server.Protocols = protocols
|
||||
log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d",
|
||||
h2cConfig.MaxConcurrentStreams,
|
||||
h2cConfig.IdleTimeout,
|
||||
h2cConfig.MaxReadFrameSize,
|
||||
h2cConfig.MaxUploadBufferPerConnection,
|
||||
h2cConfig.MaxUploadBufferPerStream,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Server{
|
||||
Addr: cfg.Server.Address(),
|
||||
Handler: httpHandler,
|
||||
// ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击
|
||||
ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second,
|
||||
// IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源
|
||||
IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second,
|
||||
// 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟
|
||||
// 不设置 ReadTimeout,因为大请求体可能需要较长时间读取
|
||||
}
|
||||
server.Handler = httpHandler
|
||||
return server
|
||||
}
|
||||
|
||||
func derefInt64(p *int64) int64 {
|
||||
|
||||
@ -90,6 +90,9 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
// 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制
|
||||
if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 {
|
||||
clientIP := ip.GetTrustedClientIP(c)
|
||||
if cfg.TrustForwardedIPForAPIKeyACL() {
|
||||
clientIP = ip.GetClientIP(c)
|
||||
}
|
||||
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
|
||||
if !allowed {
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
|
||||
|
||||
@ -398,7 +398,7 @@ func TestAPIKeyAuthRejectsUnavailableGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) {
|
||||
func TestAPIKeyAuthIPRestrictionDoesNotTrustForwardedClientIPByDefault(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
@ -460,6 +460,57 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
|
||||
require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthIPRestrictionCanTrustForwardedClientIPForReverseProxy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
IPWhitelist: []string{"1.2.3.4"},
|
||||
}
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
cfg.SetTrustForwardedIPForAPIKeyACL(true)
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
require.NoError(t, router.SetTrustedProxies(nil))
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.RemoteAddr = "9.9.9.9:12345"
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("X-Forwarded-For", "1.2.3.4")
|
||||
req.Header.Set("X-Real-IP", "1.2.3.4")
|
||||
req.Header.Set("CF-Connecting-IP", "1.2.3.4")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@ -31,7 +32,7 @@ func Logger() gin.HandlerFunc {
|
||||
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
clientIP := c.ClientIP()
|
||||
clientIP := ip.GetClientIP(c)
|
||||
protocol := c.Request.Proto
|
||||
accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64)
|
||||
platform, _ := c.Request.Context().Value(ctxkey.Platform).(string)
|
||||
|
||||
@ -180,6 +180,37 @@ func TestLogger_AccessLogIncludesCoreFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_AccessLogUsesForwardedClientIP(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(Logger())
|
||||
r.GET("/api/test", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
|
||||
req.RemoteAddr = "104.23.251.120:443"
|
||||
req.Header.Set("CF-Connecting-IP", "203.0.113.42")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("status=%d", w.Code)
|
||||
}
|
||||
|
||||
for _, event := range sink.list() {
|
||||
if event == nil || event.Message != "http request completed" {
|
||||
continue
|
||||
}
|
||||
if got := event.Fields["client_ip"]; got != "203.0.113.42" {
|
||||
t.Fatalf("client_ip=%q, want real forwarded ip", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Fatalf("access log event not found")
|
||||
}
|
||||
|
||||
func TestLogger_HealthPathSkipped(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
sink := initMiddlewareTestLogger(t)
|
||||
|
||||
@ -396,6 +396,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
codes.POST("/generate", h.Admin.Redeem.Generate)
|
||||
codes.DELETE("/:id", h.Admin.Redeem.Delete)
|
||||
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
|
||||
codes.POST("/batch-update", h.Admin.Redeem.BatchUpdate)
|
||||
codes.POST("/:id/expire", h.Admin.Redeem.Expire)
|
||||
}
|
||||
}
|
||||
|
||||
@ -511,7 +511,6 @@ func (s *AccountTestService) testBedrockAccountConnection(c *gin.Context, ctx co
|
||||
// testOpenAIAccountConnection tests an OpenAI account's connection
|
||||
func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account *Account, modelID string, prompt string, mode string) error {
|
||||
ctx := c.Request.Context()
|
||||
_ = prompt
|
||||
mode = normalizeAccountTestMode(mode)
|
||||
|
||||
// Default to openai.DefaultTestModel for OpenAI testing
|
||||
@ -572,14 +571,8 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Invalid base URL: %s", err.Error()))
|
||||
}
|
||||
// 账号已被探测为不支持 Responses(如 DeepSeek/Kimi 等)时,丢出明确提示。
|
||||
// 账号本身可用(网关会走 CC 直转),仅测试入口需要补齐 CC SSE 处理逻辑。
|
||||
// TODO:实现 CC 格式的账号测试路径(需专门的 CC SSE handler)。
|
||||
if !openai_compat.ShouldUseResponsesAPI(account.Extra) {
|
||||
return s.sendErrorAndEnd(c,
|
||||
"账号已被探测为不支持 OpenAI Responses API(如 DeepSeek/Kimi 等三方兼容上游),"+
|
||||
"账号本身可正常使用,但当前测试接口仅支持 Responses API 路径。请直接通过实际 API 调用验证。",
|
||||
)
|
||||
return s.testOpenAIChatCompletionsConnection(c, account, testModelID, prompt, normalizedBaseURL, authToken)
|
||||
}
|
||||
apiURL = buildOpenAIResponsesURL(normalizedBaseURL)
|
||||
} else {
|
||||
@ -654,6 +647,65 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
return s.processOpenAIStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testOpenAIChatCompletionsConnection tests an OpenAI-compatible APIKey account
|
||||
// through the raw /v1/chat/completions endpoint.
|
||||
func (s *AccountTestService) testOpenAIChatCompletionsConnection(
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
testModelID string,
|
||||
prompt string,
|
||||
normalizedBaseURL string,
|
||||
authToken string,
|
||||
) error {
|
||||
ctx := c.Request.Context()
|
||||
apiURL := buildOpenAIChatCompletionsURL(normalizedBaseURL)
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.Flush()
|
||||
|
||||
payload := createOpenAIChatCompletionsTestPayload(testModelID, prompt)
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
s.sendEvent(c, TestEvent{Type: "test_start", Model: testModelID})
|
||||
s.sendEvent(c, TestEvent{Type: "status", Text: "正在通过 /v1/chat/completions 测试连接"})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Failed to create Chat Completions request")
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Authorization", "Bearer "+authToken)
|
||||
|
||||
proxyURL := ""
|
||||
if account.ProxyID != nil && account.Proxy != nil {
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) request failed: %s", err.Error()))
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
s.reconcileOpenAI429State(ctx, account, resp.Header, body)
|
||||
}
|
||||
if resp.StatusCode == http.StatusUnauthorized && s.accountRepo != nil {
|
||||
errMsg := fmt.Sprintf("Chat Completions authentication failed (401): %s", string(body))
|
||||
_ = s.accountRepo.SetError(ctx, account.ID, errMsg)
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) returned %d: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
return s.processOpenAIChatCompletionsStream(c, resp.Body)
|
||||
}
|
||||
|
||||
// testOpenAICompactConnection probes /responses/compact and persists the
|
||||
// resulting capability state on the account.
|
||||
func (s *AccountTestService) testOpenAICompactConnection(c *gin.Context, account *Account, testModelID string) error {
|
||||
@ -1256,6 +1308,24 @@ func createOpenAITestPayload(modelID string, isOAuth bool, prompt string) map[st
|
||||
return payload
|
||||
}
|
||||
|
||||
func createOpenAIChatCompletionsTestPayload(modelID string, prompt string) map[string]any {
|
||||
testPrompt := strings.TrimSpace(prompt)
|
||||
if testPrompt == "" {
|
||||
testPrompt = "hi"
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"model": modelID,
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "user",
|
||||
"content": testPrompt,
|
||||
},
|
||||
},
|
||||
"stream": true,
|
||||
}
|
||||
}
|
||||
|
||||
// processClaudeStream processes the SSE stream from Claude API
|
||||
func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
@ -1310,6 +1380,82 @@ func (s *AccountTestService) processClaudeStream(c *gin.Context, body io.Reader)
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIChatCompletionsStream processes SSE chunks from the
|
||||
// OpenAI-compatible Chat Completions API.
|
||||
func (s *AccountTestService) processOpenAIChatCompletionsStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
seenJSON := false
|
||||
seenFinish := false
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
if seenFinish {
|
||||
s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
if seenJSON {
|
||||
return s.sendErrorAndEnd(c, "Chat Completions stream from /v1/chat/completions ended before [DONE]")
|
||||
}
|
||||
return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected SSE JSON data")
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions stream read error from /v1/chat/completions: %s", err.Error()))
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || !sseDataPrefix.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
jsonStr := sseDataPrefix.ReplaceAllString(line, "")
|
||||
if jsonStr == "[DONE]" {
|
||||
s.sendEvent(c, TestEvent{Type: "status", Text: "已通过 /v1/chat/completions 验证"})
|
||||
s.sendEvent(c, TestEvent{Type: "test_complete", Success: true})
|
||||
return nil
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &data); err != nil {
|
||||
return s.sendErrorAndEnd(c, "Invalid Chat Completions response from /v1/chat/completions: expected JSON data")
|
||||
}
|
||||
seenJSON = true
|
||||
|
||||
if errData, ok := data["error"].(map[string]any); ok {
|
||||
errorMsg := "Chat Completions API (/v1/chat/completions) returned an error"
|
||||
if msg, ok := errData["message"].(string); ok && msg != "" {
|
||||
errorMsg = msg
|
||||
}
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Chat Completions API (/v1/chat/completions) error: %s", errorMsg))
|
||||
}
|
||||
|
||||
choices, ok := data["choices"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, choiceValue := range choices {
|
||||
choice, ok := choiceValue.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if delta, ok := choice["delta"].(map[string]any); ok {
|
||||
if text, ok := delta["content"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
if message, ok := choice["message"].(map[string]any); ok {
|
||||
if text, ok := message["content"].(string); ok && text != "" {
|
||||
s.sendEvent(c, TestEvent{Type: "content", Text: text})
|
||||
}
|
||||
}
|
||||
if finishReason, ok := choice["finish_reason"].(string); ok && finishReason != "" {
|
||||
seenFinish = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processOpenAIStream processes the SSE stream from OpenAI Responses API
|
||||
func (s *AccountTestService) processOpenAIStream(c *gin.Context, body io.Reader) error {
|
||||
reader := bufio.NewReader(body)
|
||||
|
||||
@ -12,10 +12,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// --- shared test helpers ---
|
||||
@ -333,3 +335,145 @@ func TestAccountTestService_OpenAI401SetsPermanentErrorOnly(t *testing.T) {
|
||||
require.Zero(t, repo.clearedErrorID)
|
||||
require.Nil(t, account.RateLimitResetAt)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIAPIKeyResponsesUnsupportedUsesChatCompletionsPath(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
upstreamBody := strings.Join([]string{
|
||||
`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"pong"},"finish_reason":null}]}`,
|
||||
"",
|
||||
`data: {"id":"chatcmpl_test","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
|
||||
"",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamBody)),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 91,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://compat-upstream.example/v1",
|
||||
},
|
||||
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "hello", "")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, upstream.lastReq)
|
||||
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.Equal(t, "Bearer sk-test", upstream.lastReq.Header.Get("Authorization"))
|
||||
require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept"))
|
||||
require.Equal(t, "gpt-5.4", gjson.GetBytes(upstream.lastBody, "model").String())
|
||||
require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool())
|
||||
require.Equal(t, "hello", gjson.GetBytes(upstream.lastBody, "messages.0.content").String())
|
||||
require.False(t, gjson.GetBytes(upstream.lastBody, "input").Exists())
|
||||
body := recorder.Body.String()
|
||||
require.Contains(t, body, "pong")
|
||||
require.Contains(t, body, "已通过 /v1/chat/completions 验证")
|
||||
require.Contains(t, body, `"success":true`)
|
||||
require.NotContains(t, body, "当前测试接口仅支持 Responses API 路径")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIChatCompletionsPathReturns4xx(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: newJSONResponse(http.StatusBadRequest, `{"error":{"message":"bad request"}}`)}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 92,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://compat-upstream.example",
|
||||
},
|
||||
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) returned 400")
|
||||
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
|
||||
require.NotContains(t, recorder.Body.String(), `"success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIChatCompletionsPathTimeout(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
upstream := &httpUpstreamRecorder{err: context.DeadlineExceeded}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 93,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://compat-upstream.example",
|
||||
},
|
||||
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.Contains(t, err.Error(), "Chat Completions API (/v1/chat/completions) request failed")
|
||||
require.Contains(t, err.Error(), context.DeadlineExceeded.Error())
|
||||
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
|
||||
require.NotContains(t, recorder.Body.String(), `"success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAIChatCompletionsPathRejectsNonJSONStream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newTestContext()
|
||||
|
||||
upstream := &httpUpstreamRecorder{resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
Body: io.NopCloser(strings.NewReader("data: not-json\n\n")),
|
||||
}}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 94,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test",
|
||||
"base_url": "https://compat-upstream.example",
|
||||
},
|
||||
Extra: map[string]any{openai_compat.ExtraKeyResponsesSupported: false},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4", "", "")
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "https://compat-upstream.example/v1/chat/completions", upstream.lastReq.URL.String())
|
||||
require.Contains(t, err.Error(), "Invalid Chat Completions response from /v1/chat/completions")
|
||||
require.Contains(t, recorder.Body.String(), "/v1/chat/completions")
|
||||
require.NotContains(t, recorder.Body.String(), `"success":true`)
|
||||
}
|
||||
|
||||
@ -531,6 +531,7 @@ type adminServiceImpl struct {
|
||||
defaultSubAssigner DefaultSubscriptionAssigner
|
||||
userSubRepo UserSubscriptionRepository
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
}
|
||||
|
||||
type userGroupRateBatchReader interface {
|
||||
@ -556,6 +557,7 @@ func NewAdminService(
|
||||
defaultSubAssigner DefaultSubscriptionAssigner,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
runtimeBlocker AccountRuntimeBlocker,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
@ -575,6 +577,7 @@ func NewAdminService(
|
||||
defaultSubAssigner: defaultSubAssigner,
|
||||
userSubRepo: userSubRepo,
|
||||
privacyClientFactory: privacyClientFactory,
|
||||
runtimeBlocker: runtimeBlocker,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2791,6 +2794,9 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.runtimeBlocker != nil {
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(id)
|
||||
}
|
||||
return s.accountRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
|
||||
@ -70,7 +70,8 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
||||
TempUnschedulableReason: "missing refresh token",
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
svc := &adminServiceImpl{accountRepo: repo, runtimeBlocker: blocker}
|
||||
|
||||
updated, err := svc.ClearAccountError(context.Background(), 31)
|
||||
require.NoError(t, err)
|
||||
@ -83,4 +84,5 @@ func TestAdminService_ClearAccountError_AlsoClearsRecoverableRuntimeState(t *tes
|
||||
require.Nil(t, updated.RateLimitResetAt)
|
||||
require.Nil(t, updated.TempUnschedulableUntil)
|
||||
require.Empty(t, updated.TempUnschedulableReason)
|
||||
require.Equal(t, []int64{31}, blocker.clearedIDs)
|
||||
}
|
||||
|
||||
@ -325,6 +325,12 @@ func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxy
|
||||
type redeemRepoStub struct {
|
||||
deleteErrByID map[int64]error
|
||||
deletedIDs []int64
|
||||
|
||||
batchUpdateIDs []int64
|
||||
batchUpdateFields RedeemCodeBatchUpdateFields
|
||||
batchUpdateResult int64
|
||||
batchUpdateErr error
|
||||
batchUpdateCalled bool
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
@ -347,6 +353,19 @@ func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
s.batchUpdateCalled = true
|
||||
s.batchUpdateIDs = append([]int64(nil), ids...)
|
||||
s.batchUpdateFields = fields
|
||||
if s.batchUpdateErr != nil {
|
||||
return 0, s.batchUpdateErr
|
||||
}
|
||||
if s.batchUpdateResult != 0 {
|
||||
return s.batchUpdateResult, nil
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
if s.deleteErrByID != nil {
|
||||
|
||||
@ -49,7 +49,7 @@ func (s *AuthService) loginOrRegisterVerifiedEmailOAuth(
|
||||
}
|
||||
|
||||
providerType := normalizeOAuthSignupSource(input.ProviderType)
|
||||
if providerType != "github" && providerType != "google" {
|
||||
if providerType != "github" && providerType != "google" && providerType != "oidc" {
|
||||
return nil, nil, infraerrors.BadRequest("OAUTH_PROVIDER_INVALID", "oauth provider is invalid")
|
||||
}
|
||||
providerKey := strings.TrimSpace(input.ProviderKey)
|
||||
|
||||
@ -59,6 +59,10 @@ func (s *redeemCodeRepoStub) Update(_ context.Context, code *RedeemCode) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
panic("unexpected BatchUpdate call")
|
||||
}
|
||||
|
||||
func (s *redeemCodeRepoStub) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
@ -9,12 +9,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const defaultBedrockRegion = "us-east-1"
|
||||
|
||||
// featureKeyBedrockCCCompat is the key used in Channel.FeaturesConfig for Bedrock CC compatibility.
|
||||
const featureKeyBedrockCCCompat = "bedrock_cc_compat"
|
||||
|
||||
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||
|
||||
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
|
||||
@ -179,13 +183,16 @@ func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||
// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config)
|
||||
// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true})
|
||||
// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl)
|
||||
// 6. 修复 thinking 字段兼容性(Opus 4.7 仅支持 adaptive,enabled 需要 budget_tokens)
|
||||
// 7. 清理 tool_use.id / tool_use_id 中 Bedrock 不接受的字符
|
||||
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
|
||||
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
|
||||
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens, false)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
|
||||
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
|
||||
// ccCompat 启用 CC 兼容模式时额外处理 thinking 类型转换和 tool_use.id 清理。
|
||||
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string, ccCompat bool) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
// 注入 anthropic_version(Bedrock 要求)
|
||||
@ -203,6 +210,7 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Injected beta tokens: %v (model=%s ccCompat=%v)", betaTokens, modelID, ccCompat)
|
||||
}
|
||||
|
||||
// 移除 model 字段(Bedrock 通过 URL 指定模型)
|
||||
@ -235,6 +243,12 @@ func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens
|
||||
// 清理 cache_control 中 Bedrock 不支持的字段
|
||||
body = sanitizeBedrockCacheControl(body, modelID)
|
||||
|
||||
// CC 兼容模式:修复 CC 发送的 Bedrock 不兼容字段
|
||||
if ccCompat {
|
||||
body = sanitizeBedrockThinking(body, modelID)
|
||||
body = sanitizeBedrockToolUseIDs(body)
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
@ -444,17 +458,17 @@ func parseAnthropicBetaHeader(header string) []string {
|
||||
}
|
||||
|
||||
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
|
||||
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
|
||||
// 参考: AWS Bedrock 官方文档 + litellm anthropic_beta_headers_config.json
|
||||
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
|
||||
var bedrockSupportedBetaTokens = map[string]bool{
|
||||
"computer-use-2025-01-24": true,
|
||||
"computer-use-2025-11-24": true,
|
||||
"context-1m-2025-08-07": true,
|
||||
"context-management-2025-06-27": true,
|
||||
"compact-2026-01-12": true,
|
||||
"interleaved-thinking-2025-05-14": true,
|
||||
"tool-search-tool-2025-10-19": true,
|
||||
"tool-examples-2025-10-29": true,
|
||||
"computer-use-2025-01-24": true,
|
||||
"computer-use-2025-11-24": true,
|
||||
"context-1m-2025-08-07": true,
|
||||
// "context-management-2025-06-27": false, // 无官方文档支持
|
||||
"compact-2026-01-12": true, // 官方支持,仅 InvokeModel API(Opus 4.6+)
|
||||
// "interleaved-thinking-2025-05-14": false, // 无官方文档支持
|
||||
"tool-search-tool-2025-10-19": true,
|
||||
"tool-examples-2025-10-29": true,
|
||||
}
|
||||
|
||||
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
|
||||
@ -482,11 +496,8 @@ func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) [
|
||||
}
|
||||
}
|
||||
|
||||
// 检测 thinking / interleaved thinking
|
||||
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
inject("interleaved-thinking-2025-05-14")
|
||||
}
|
||||
// 注意:thinking 字段不再自动注入 interleaved-thinking-2025-05-14
|
||||
// 因为该 beta token 未在 AWS Bedrock 官方文档中确认支持
|
||||
|
||||
// 检测 computer_use 工具
|
||||
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||
@ -605,3 +616,156 @@ func filterBedrockBetaTokens(tokens []string) []string {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// bedrockToolUseIDRe 匹配 Bedrock 允许的 tool_use ID 字符(字母、数字、下划线、连字符)
|
||||
var bedrockToolUseIDRe = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||
|
||||
// isBedrockOpus47OrNewer 判断 Bedrock 模型 ID 是否为 Claude Opus 4.7 或更新版本
|
||||
// Opus 4.7 仅支持 thinking.type: "adaptive",不支持 "enabled"
|
||||
func isBedrockOpus47OrNewer(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
if !strings.Contains(lower, "opus") {
|
||||
return false
|
||||
}
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 7)
|
||||
}
|
||||
|
||||
const defaultThinkingBudgetTokens = 10000
|
||||
|
||||
// sanitizeBedrockThinking 修复 thinking 字段的 Bedrock 兼容性问题:
|
||||
// - Opus 4.7+: 仅支持 "adaptive",将 "enabled" 转换为 "adaptive" 并移除 budget_tokens
|
||||
// - 其他模型: "enabled" 必须带 budget_tokens,缺失时补充默认值
|
||||
func sanitizeBedrockThinking(body []byte, modelID string) []byte {
|
||||
thinking := gjson.GetBytes(body, "thinking")
|
||||
if !thinking.Exists() || !thinking.IsObject() {
|
||||
return body
|
||||
}
|
||||
|
||||
thinkingType := thinking.Get("type").String()
|
||||
if thinkingType == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
if isBedrockOpus47OrNewer(modelID) {
|
||||
if thinkingType == "enabled" {
|
||||
body, _ = sjson.SetBytes(body, "thinking.type", "adaptive")
|
||||
body, _ = sjson.DeleteBytes(body, "thinking.budget_tokens")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
if thinkingType == "enabled" && !thinking.Get("budget_tokens").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", defaultThinkingBudgetTokens)
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// sanitizeBedrockToolUseIDs 清理 messages 中 tool_use.id 和 tool_result.tool_use_id
|
||||
// 的非法字符。Bedrock 要求 ID 匹配 '^[a-zA-Z0-9_-]+$'。
|
||||
func sanitizeBedrockToolUseIDs(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
for mi, msg := range messages.Array() {
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
for ci, block := range content.Array() {
|
||||
switch block.Get("type").String() {
|
||||
case "tool_use":
|
||||
body = sanitizeIDField(body, block.Get("id").String(), fmt.Sprintf("messages.%d.content.%d.id", mi, ci))
|
||||
case "tool_result":
|
||||
body = sanitizeIDField(body, block.Get("tool_use_id").String(), fmt.Sprintf("messages.%d.content.%d.tool_use_id", mi, ci))
|
||||
}
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func sanitizeIDField(body []byte, id, path string) []byte {
|
||||
if id == "" {
|
||||
return body
|
||||
}
|
||||
sanitized := bedrockToolUseIDRe.ReplaceAllString(id, "_")
|
||||
if sanitized != id {
|
||||
body, _ = sjson.SetBytes(body, path, sanitized)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
const defaultCCMaxTokens = 81920
|
||||
|
||||
// sanitizeBedrockCCFields 处理 Claude Code 发送的 Bedrock 不兼容字段:
|
||||
// - 移除 service_tier(Anthropic API 专有,Bedrock 不支持)
|
||||
// - 移除 interface_geo(Anthropic API 专有,Bedrock 不支持)
|
||||
// - 移除 context_management(Anthropic API 专有,Bedrock 不支持,CC v2.1.87+ 默认携带)
|
||||
// - 注入 max_tokens 默认值 81920(CC 可能省略,Bedrock 要求必须提供)
|
||||
// - 注入 anthropic_version(CC 通过 HTTP 头发送,Bedrock 需要放在请求体中)
|
||||
func sanitizeBedrockCCFields(body []byte) []byte {
|
||||
if gjson.GetBytes(body, "service_tier").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, "service_tier")
|
||||
}
|
||||
if gjson.GetBytes(body, "interface_geo").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, "interface_geo")
|
||||
}
|
||||
if gjson.GetBytes(body, "context_management").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, "context_management")
|
||||
}
|
||||
if !gjson.GetBytes(body, "max_tokens").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", defaultCCMaxTokens)
|
||||
}
|
||||
if !gjson.GetBytes(body, "anthropic_version").Exists() {
|
||||
body, _ = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// sanitizeBedrockCCBetaTokens 清理请求体中的 anthropic_beta 字段,只保留 Bedrock 支持的 beta token
|
||||
// CC 可能在请求体中注入了 Bedrock 不支持的 beta token(如 prompt-caching 等),导致 ValidationException
|
||||
func sanitizeBedrockCCBetaTokens(body []byte, modelID string) []byte {
|
||||
betaField := gjson.GetBytes(body, "anthropic_beta")
|
||||
if !betaField.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
var tokens []string
|
||||
if betaField.IsArray() {
|
||||
for _, t := range betaField.Array() {
|
||||
if t.Type == gjson.String {
|
||||
tokens = append(tokens, t.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
originalTokens := append([]string(nil), tokens...) // 保存原始 tokens 用于日志
|
||||
|
||||
// 复用现有的 Bedrock beta token 过滤逻辑(自动注入 + 白名单过滤 + 转换)
|
||||
// 即使 tokens 为空,也要执行自动注入(根据 body 内容补充必要的 beta token)
|
||||
tokens = autoInjectBedrockBetaTokens(tokens, body, modelID)
|
||||
tokens = filterBedrockBetaTokens(tokens)
|
||||
|
||||
if len(tokens) == 0 {
|
||||
// 所有 token 都被过滤掉,删除 anthropic_beta 字段
|
||||
body, _ = sjson.DeleteBytes(body, "anthropic_beta")
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Removed all beta tokens: original=%v", originalTokens)
|
||||
} else {
|
||||
// 更新为过滤后的 token 列表
|
||||
body, _ = sjson.SetBytes(body, "anthropic_beta", tokens)
|
||||
if len(originalTokens) > 0 {
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Filtered beta tokens: original=%v final=%v", originalTokens, tokens)
|
||||
} else {
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock CC Compat] Auto-injected beta tokens: %v", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@ -216,7 +216,7 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
|
||||
]
|
||||
}`
|
||||
|
||||
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
|
||||
betaHeader := "context-1m-2025-08-07, compact-2026-01-12"
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -228,10 +228,9 @@ func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
|
||||
|
||||
// anthropic_beta 应包含所有 beta tokens
|
||||
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, betaArr, 3)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
|
||||
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
|
||||
require.Len(t, betaArr, 2)
|
||||
assert.Equal(t, "context-1m-2025-08-07", betaArr[0].String())
|
||||
assert.Equal(t, "compact-2026-01-12", betaArr[1].String())
|
||||
|
||||
// output_format 应被移除,schema 内联到最后一条 user message
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
@ -264,29 +263,29 @@ func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("single beta token", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07 , compact-2026-01-12 ")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
|
||||
assert.Equal(t, "compact-2026-01-12", arr[1].String())
|
||||
})
|
||||
|
||||
t.Run("json array beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["context-1m-2025-08-07","compact-2026-01-12"]`)
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[0].String())
|
||||
assert.Equal(t, "compact-2026-01-12", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
@ -301,15 +300,15 @@ func TestParseAnthropicBetaHeader(t *testing.T) {
|
||||
|
||||
func TestFilterBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("supported tokens pass through", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
tokens := []string{"context-1m-2025-08-07", "compact-2026-01-12", "computer-use-2025-11-24"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, tokens, result)
|
||||
})
|
||||
|
||||
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||
tokens := []string{"context-1m-2025-08-07", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
|
||||
@ -361,11 +360,11 @@ func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
|
||||
|
||||
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
|
||||
"compact-2026-01-12, output-128k-2025-02-19, files-api-2025-04-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "compact-2026-01-12", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
|
||||
@ -498,22 +497,17 @@ func TestResolveBedrockModelID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||
t.Run("no auto-inject for thinking (interleaved-thinking not supported)", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
// interleaved-thinking-2025-05-14 已从白名单移除,不应自动注入
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "interleaved-thinking-2025-05-14" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
result := autoInjectBedrockBetaTokens([]string{"context-1m-2025-08-07"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
|
||||
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||
@ -574,7 +568,8 @@ func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||
assert.Contains(t, result, "compact-2026-01-12")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
// interleaved-thinking 不再自动注入
|
||||
assert.NotContains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
@ -588,27 +583,21 @@ func TestResolveBedrockBetaTokens(t *testing.T) {
|
||||
|
||||
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
result := ResolveBedrockBetaTokens("context-1m-2025-08-07,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
|
||||
t.Run("thinking in body does not auto-inject beta (not supported)", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
found := false
|
||||
for _, v := range arr {
|
||||
if v.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "interleaved-thinking should be auto-injected")
|
||||
// interleaved-thinking 已从白名单移除,不应自动注入
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
|
||||
t.Run("header tokens preserved without auto-injection", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||
require.NoError(t, err)
|
||||
@ -618,7 +607,8 @@ func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||
names[i] = v.String()
|
||||
}
|
||||
assert.Contains(t, names, "context-1m-2025-08-07")
|
||||
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
|
||||
// interleaved-thinking 不再自动注入
|
||||
assert.NotContains(t, names, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
@ -657,3 +647,363 @@ func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBedrockOpus47OrNewer(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
expect bool
|
||||
}{
|
||||
{"us.anthropic.claude-opus-4-7-v1", true},
|
||||
{"us.anthropic.claude-opus-4-6-v1", false},
|
||||
{"us.anthropic.claude-opus-4-5-20251101-v1:0", false},
|
||||
{"us.anthropic.claude-opus-5-0-v1", true},
|
||||
// Sonnet 4.7 is not Opus → false
|
||||
{"us.anthropic.claude-sonnet-4-7-v1", false},
|
||||
{"us.anthropic.claude-sonnet-4-6", false},
|
||||
// Haiku is not Opus
|
||||
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", false},
|
||||
// Non-Claude models
|
||||
{"amazon.nova-pro-v1", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockThinking(t *testing.T) {
|
||||
t.Run("opus 4.7 converts enabled to adaptive", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
|
||||
})
|
||||
|
||||
t.Run("opus 4.7 keeps adaptive unchanged", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"adaptive"},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
})
|
||||
|
||||
t.Run("opus 4.6 enabled without budget_tokens gets default", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled"},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
})
|
||||
|
||||
t.Run("opus 4.6 enabled with budget_tokens unchanged", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled","budget_tokens":20000},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(20000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
})
|
||||
|
||||
t.Run("no thinking field unchanged", func(t *testing.T) {
|
||||
input := `{"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("sonnet 4.6 enabled without budget_tokens gets default", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled"},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-sonnet-4-6")
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockToolUseIDs(t *testing.T) {
|
||||
t.Run("clean IDs unchanged", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu_01AbCdEf","name":"bash","input":{}}]}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.Equal(t, "toolu_01AbCdEf", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
})
|
||||
|
||||
t.Run("dots in tool_use ID replaced with underscores", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
})
|
||||
|
||||
t.Run("special chars in tool_use ID sanitized", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"toolu:01@Ab#Cd","name":"bash","input":{}}]}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
id := gjson.GetBytes(result, "messages.0.content.0.id").String()
|
||||
assert.Regexp(t, `^[a-zA-Z0-9_-]+$`, id)
|
||||
})
|
||||
|
||||
t.Run("tool_result tool_use_id sanitized", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.tool_use_id").String())
|
||||
})
|
||||
|
||||
t.Run("mixed clean and dirty IDs", func(t *testing.T) {
|
||||
input := `{"messages":[
|
||||
{"role":"assistant","content":[{"type":"tool_use","id":"clean_id-123","name":"a","input":{}}]},
|
||||
{"role":"user","content":[{"type":"tool_result","tool_use_id":"dirty.id@456","content":"ok"}]},
|
||||
{"role":"assistant","content":[{"type":"tool_use","id":"also.dirty","name":"b","input":{}}]}
|
||||
]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.Equal(t, "clean_id-123", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
assert.Equal(t, "dirty_id_456", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String())
|
||||
assert.Equal(t, "also_dirty", gjson.GetBytes(result, "messages.2.content.0.id").String())
|
||||
})
|
||||
|
||||
t.Run("no messages unchanged", func(t *testing.T) {
|
||||
input := `{"system":[{"type":"text","text":"hi"}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("string content skipped", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"plain text"}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("empty ID skipped", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"assistant","content":[{"type":"tool_use","id":"","name":"a","input":{}}]}]}`
|
||||
result := sanitizeBedrockToolUseIDs([]byte(input))
|
||||
assert.Equal(t, "", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockThinking_EdgeCases(t *testing.T) {
|
||||
t.Run("opus 4.7 enabled without budget_tokens converts to adaptive", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled"},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
|
||||
})
|
||||
|
||||
t.Run("thinking type disabled unchanged", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"disabled"},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.Equal(t, "disabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
})
|
||||
|
||||
t.Run("thinking type empty string unchanged", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":""},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("thinking is not an object unchanged", func(t *testing.T) {
|
||||
input := `{"thinking":true,"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("opus 4.7 adaptive with budget_tokens preserved", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"adaptive","budget_tokens":5000},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "us.anthropic.claude-opus-4-7-v1")
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(5000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
})
|
||||
|
||||
// Forward() passes parsed.Model (standard names like "claude-opus-4-7")
|
||||
t.Run("standard model name opus 4.7 converts enabled to adaptive", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-7")
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
|
||||
})
|
||||
|
||||
t.Run("standard model name opus 4.6 keeps enabled", func(t *testing.T) {
|
||||
input := `{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[]}`
|
||||
result := sanitizeBedrockThinking([]byte(input), "claude-opus-4-6")
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(10000), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsBedrockOpus47OrNewer_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
expect bool
|
||||
}{
|
||||
{"anthropic.claude-opus-4-7-v1", true},
|
||||
{"us.anthropic.claude-opus-4-7-20270101-v1:0", true},
|
||||
{"", false},
|
||||
// Forward() passes parsed.Model (standard names), not Bedrock IDs
|
||||
{"claude-opus-4-7", true},
|
||||
{"claude-opus-4-6", false},
|
||||
{"claude-sonnet-4-7", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, isBedrockOpus47OrNewer(tt.modelID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBodyWithTokens_CCCompat(t *testing.T) {
|
||||
input := `{
|
||||
"model":"claude-opus-4-6",
|
||||
"stream":true,
|
||||
"max_tokens":16384,
|
||||
"thinking":{"type":"enabled"},
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"tool_use","id":"toolu.01.Ab","name":"bash","input":{}}]},
|
||||
{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu.01.Ab","content":"ok"}]}
|
||||
]
|
||||
}`
|
||||
|
||||
t.Run("ccCompat=false skips thinking and toolUseID sanitization", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
|
||||
assert.Equal(t, "toolu.01.Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
})
|
||||
|
||||
t.Run("ccCompat=true applies thinking fix and toolUseID sanitization (opus 4.6)", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-6-v1", nil, true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.Equal(t, int64(defaultThinkingBudgetTokens), gjson.GetBytes(result, "thinking.budget_tokens").Int())
|
||||
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.1.content.0.tool_use_id").String())
|
||||
})
|
||||
|
||||
t.Run("ccCompat=true converts thinking to adaptive for opus 4.7", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBodyWithTokens([]byte(input), "us.anthropic.claude-opus-4-7-v1", nil, true)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "adaptive", gjson.GetBytes(result, "thinking.type").String())
|
||||
assert.False(t, gjson.GetBytes(result, "thinking.budget_tokens").Exists())
|
||||
assert.Equal(t, "toolu_01_Ab", gjson.GetBytes(result, "messages.0.content.0.id").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCCFields(t *testing.T) {
|
||||
t.Run("removes service_tier and interface_geo", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","service_tier":"standard","interface_geo":"us","messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
|
||||
assert.True(t, gjson.GetBytes(result, "messages").Exists())
|
||||
})
|
||||
|
||||
t.Run("removes context_management", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},"messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
|
||||
assert.True(t, gjson.GetBytes(result, "messages").Exists())
|
||||
})
|
||||
|
||||
t.Run("injects max_tokens when missing", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
|
||||
})
|
||||
|
||||
t.Run("preserves existing max_tokens", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","max_tokens":4096,"messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.Equal(t, int64(4096), gjson.GetBytes(result, "max_tokens").Int())
|
||||
})
|
||||
|
||||
t.Run("injects anthropic_version when missing", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
})
|
||||
|
||||
t.Run("preserves existing anthropic_version", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","anthropic_version":"bedrock-2023-05-31","messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
})
|
||||
|
||||
t.Run("no-op when fields already clean", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-opus-4-6","max_tokens":81920,"anthropic_version":"bedrock-2023-05-31","messages":[]}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
|
||||
})
|
||||
|
||||
t.Run("full CC request sanitization", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model":"claude-opus-4-6",
|
||||
"service_tier":"standard",
|
||||
"interface_geo":"us",
|
||||
"context_management":{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]},
|
||||
"messages":[{"role":"user","content":"hello"}],
|
||||
"thinking":{"type":"enabled"}
|
||||
}`)
|
||||
result := sanitizeBedrockCCFields(body)
|
||||
assert.False(t, gjson.GetBytes(result, "service_tier").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "interface_geo").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "context_management").Exists())
|
||||
assert.Equal(t, int64(defaultCCMaxTokens), gjson.GetBytes(result, "max_tokens").Int())
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
assert.Equal(t, "enabled", gjson.GetBytes(result, "thinking.type").String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCCBetaTokens(t *testing.T) {
|
||||
t.Run("filters unsupported beta tokens", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":["prompt-caching-2024-07-31","context-1m-2025-08-07","unsupported-feature"],"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
beta := gjson.GetBytes(result, "anthropic_beta")
|
||||
assert.True(t, beta.Exists())
|
||||
assert.True(t, beta.IsArray())
|
||||
tokens := beta.Array()
|
||||
assert.Equal(t, 1, len(tokens))
|
||||
assert.Equal(t, "context-1m-2025-08-07", tokens[0].String())
|
||||
})
|
||||
|
||||
t.Run("removes anthropic_beta if all tokens filtered", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":["prompt-caching-2024-07-31","unsupported-feature"],"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("thinking alone does not auto-inject beta tokens", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":[],"thinking":{"type":"enabled"},"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("auto-injects computer-use beta token", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":[],"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
beta := gjson.GetBytes(result, "anthropic_beta")
|
||||
assert.True(t, beta.Exists())
|
||||
tokens := beta.Array()
|
||||
assert.Equal(t, 1, len(tokens))
|
||||
assert.Equal(t, "computer-use-2025-11-24", tokens[0].String())
|
||||
})
|
||||
|
||||
t.Run("transforms advanced-tool-use to tool-search-tool", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":["advanced-tool-use-2025-11-20"],"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
beta := gjson.GetBytes(result, "anthropic_beta")
|
||||
tokens := beta.Array()
|
||||
assert.Equal(t, 2, len(tokens)) // tool-search-tool + tool-examples (auto-associated)
|
||||
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("no-op when anthropic_beta not present", func(t *testing.T) {
|
||||
input := `{"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("preserves supported beta tokens", func(t *testing.T) {
|
||||
input := `{"anthropic_beta":["computer-use-2025-11-24","context-1m-2025-08-07"],"messages":[]}`
|
||||
result := sanitizeBedrockCCBetaTokens([]byte(input), "claude-opus-4-6")
|
||||
beta := gjson.GetBytes(result, "anthropic_beta")
|
||||
tokens := beta.Array()
|
||||
assert.Equal(t, 2, len(tokens))
|
||||
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "computer-use-2025-11-24")
|
||||
assert.Contains(t, []string{tokens[0].String(), tokens[1].String()}, "context-1m-2025-08-07")
|
||||
})
|
||||
}
|
||||
|
||||
@ -248,6 +248,17 @@ func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// IsBedrockCCCompatEnabled 返回该渠道是否启用了 Bedrock CC 兼容模式。
|
||||
// 一旦启用,该渠道下所有请求都会应用 CC 兼容转换,不区分账号 platform。
|
||||
func (c *Channel) IsBedrockCCCompatEnabled(platform string) bool {
|
||||
if c == nil || c.FeaturesConfig == nil {
|
||||
return false
|
||||
}
|
||||
// 直接检查 bedrock_cc_compat 开关,不再检查 platform 子字段
|
||||
enabled, ok := c.FeaturesConfig[featureKeyBedrockCCCompat].(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
|
||||
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
|
||||
dst := make(map[string]any, len(src))
|
||||
|
||||
73
backend/internal/service/channel_bedrock_cc_test.go
Normal file
73
backend/internal/service/channel_bedrock_cc_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_Enabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyBedrockCCCompat: true,
|
||||
},
|
||||
}
|
||||
require.True(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_AppliesToAllPlatforms(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyBedrockCCCompat: true,
|
||||
},
|
||||
}
|
||||
require.True(t, c.IsBedrockCCCompatEnabled("anthropic"))
|
||||
require.True(t, c.IsBedrockCCCompatEnabled("openai"))
|
||||
require.True(t, c.IsBedrockCCCompatEnabled(""))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_Disabled(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyBedrockCCCompat: false,
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_NilFeaturesConfig(t *testing.T) {
|
||||
c := &Channel{FeaturesConfig: nil}
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_NilChannel(t *testing.T) {
|
||||
var c *Channel
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_WrongType(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyBedrockCCCompat: "yes",
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_OldMapFormat(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
featureKeyBedrockCCCompat: map[string]any{"bedrock": true},
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
|
||||
func TestChannel_IsBedrockCCCompatEnabled_MissingKey(t *testing.T) {
|
||||
c := &Channel{
|
||||
FeaturesConfig: map[string]any{
|
||||
"other_feature": true,
|
||||
},
|
||||
}
|
||||
require.False(t, c.IsBedrockCCCompatEnabled("bedrock"))
|
||||
}
|
||||
@ -281,9 +281,54 @@ func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt
|
||||
if err != nil {
|
||||
return "", "", status, err
|
||||
}
|
||||
if provider == MonitorProviderOpenAI && apiMode == MonitorAPIModeResponses {
|
||||
return extractOpenAIResponsesText(respBytes), string(respBytes), status, nil
|
||||
}
|
||||
return gjson.GetBytes(respBytes, adapter.textPath).String(), string(respBytes), status, nil
|
||||
}
|
||||
|
||||
// extractOpenAIResponsesText 聚合 Responses API 的最终 assistant 文本。
|
||||
// Responses 的 output 数组顺序由模型决定:reasoning / tool-call item 可能排在 message 前面,
|
||||
// 因此不能假设文本永远在 output.0.content.0.text。
|
||||
func extractOpenAIResponsesText(respBytes []byte) string {
|
||||
if text := gjson.GetBytes(respBytes, "output_text").String(); strings.TrimSpace(text) != "" {
|
||||
return text
|
||||
}
|
||||
|
||||
var texts []string
|
||||
outputs := gjson.GetBytes(respBytes, "output")
|
||||
if outputs.IsArray() {
|
||||
outputs.ForEach(func(_, output gjson.Result) bool {
|
||||
outputType := output.Get("type").String()
|
||||
if outputType != "" && outputType != "message" {
|
||||
return true
|
||||
}
|
||||
|
||||
content := output.Get("content")
|
||||
if !content.IsArray() {
|
||||
return true
|
||||
}
|
||||
|
||||
content.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType != "" && blockType != "output_text" {
|
||||
return true
|
||||
}
|
||||
if text := block.Get("text").String(); strings.TrimSpace(text) != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if len(texts) > 0 {
|
||||
return strings.Join(texts, "")
|
||||
}
|
||||
return gjson.GetBytes(respBytes, providerOpenAIResponsesAdapter.textPath).String()
|
||||
}
|
||||
|
||||
// mergeHeaders 把用户自定义 headers 合并到 adapter 默认 headers 上。
|
||||
// 用户值覆盖默认;命中黑名单(hop-by-hop / 由 http.Client 自管的)的 key 静默丢弃。
|
||||
func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string {
|
||||
|
||||
@ -60,10 +60,11 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
|
||||
}
|
||||
|
||||
type openAICaptureHandler struct {
|
||||
lastBody map[string]any
|
||||
lastHeaders http.Header
|
||||
lastPath string
|
||||
status int
|
||||
lastBody map[string]any
|
||||
lastHeaders http.Header
|
||||
lastPath string
|
||||
status int
|
||||
responsesLeadingReasoning bool
|
||||
}
|
||||
|
||||
func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@ -82,10 +83,23 @@ func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
answer := answerFromOpenAIRequest(parsed)
|
||||
if h.lastPath == providerOpenAIResponsesPath {
|
||||
output := []map[string]any{}
|
||||
if h.responsesLeadingReasoning {
|
||||
output = append(output, map[string]any{
|
||||
"type": "reasoning",
|
||||
"summary": []any{},
|
||||
})
|
||||
}
|
||||
output = append(output, map[string]any{
|
||||
"type": "message",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": []map[string]any{
|
||||
{"type": "output_text", "text": answer},
|
||||
},
|
||||
})
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"output": []map[string]any{{
|
||||
"content": []map[string]any{{"type": "output_text", "text": answer}},
|
||||
}},
|
||||
"output": output,
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -212,6 +226,22 @@ func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OpenAIResponses_SkipsLeadingReasoningItem(t *testing.T) {
|
||||
h := &openAICaptureHandler{responsesLeadingReasoning: true}
|
||||
endpoint := setupFakeOpenAI(t, h)
|
||||
|
||||
res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-5.5", &CheckOptions{
|
||||
APIMode: MonitorAPIModeResponses,
|
||||
})
|
||||
|
||||
if res.Status != MonitorStatusOperational {
|
||||
t.Fatalf("responses request should find text after leading reasoning item, got status=%s message=%q", res.Status, res.Message)
|
||||
}
|
||||
if h.lastPath != providerOpenAIResponsesPath {
|
||||
t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) {
|
||||
h := &openAICaptureHandler{}
|
||||
endpoint := setupFakeOpenAI(t, h)
|
||||
|
||||
@ -3,13 +3,17 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
@ -79,18 +83,50 @@ func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
// 默认等待队列额外槽位
|
||||
defaultExtraWaitSlots = 20
|
||||
|
||||
defaultAccountLoadBatchCacheTTL = 200 * time.Millisecond
|
||||
accountLoadBatchFetchTimeout = 3 * time.Second
|
||||
maxAccountLoadBatchCacheEntries = 256
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
// ConcurrencyService 管理账号和用户的并发限制。
|
||||
type ConcurrencyService struct {
|
||||
cache ConcurrencyCache
|
||||
|
||||
accountLoadCacheTTL atomic.Int64
|
||||
accountLoadCacheMu sync.RWMutex
|
||||
accountLoadCache map[string]cachedAccountLoadBatch
|
||||
accountLoadGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
type cachedAccountLoadBatch struct {
|
||||
loadMap map[int64]*AccountLoadInfo
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewConcurrencyService 创建并发控制服务。
|
||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
svc := &ConcurrencyService{
|
||||
cache: cache,
|
||||
accountLoadCache: make(map[string]cachedAccountLoadBatch),
|
||||
}
|
||||
svc.SetAccountLoadBatchCacheTTL(defaultAccountLoadBatchCacheTTL)
|
||||
return svc
|
||||
}
|
||||
|
||||
// SetAccountLoadBatchCacheTTL 设置账号负载批量读取的极短 TTL 缓存;非正数表示禁用缓存。
|
||||
func (s *ConcurrencyService) SetAccountLoadBatchCacheTTL(ttl time.Duration) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.accountLoadCacheTTL.Store(int64(ttl))
|
||||
if ttl <= 0 {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
@ -284,12 +320,140 @@ func CalculateMaxWait(userConcurrency int) int {
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
// GetAccountsLoadBatch 批量获取账号负载信息。
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return s.getAccountsLoadBatch(ctx, accounts, true)
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatchFresh 绕过极短 TTL 缓存,用于抢槽失败后的实时刷新兜底。
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatchFresh(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return s.getAccountsLoadBatch(ctx, accounts, false)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) getAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency, allowCache bool) (map[int64]*AccountLoadInfo, error) {
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
|
||||
ttl := time.Duration(s.accountLoadCacheTTL.Load())
|
||||
if !allowCache || ttl <= 0 {
|
||||
return s.fetchAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
key := accountLoadBatchCacheKey(accounts)
|
||||
if cached, ok := s.getCachedAccountLoadBatch(key, time.Now()); ok {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
value, err, _ := s.accountLoadGroup.Do(key, func() (any, error) {
|
||||
now := time.Now()
|
||||
if cached, ok := s.getCachedAccountLoadBatch(key, now); ok {
|
||||
return cached, nil
|
||||
}
|
||||
loadMap, fetchErr := s.fetchAccountsLoadBatch(ctx, accounts)
|
||||
if fetchErr != nil {
|
||||
return nil, fetchErr
|
||||
}
|
||||
cached := cloneAccountLoadMap(loadMap)
|
||||
s.storeCachedAccountLoadBatch(key, cached, now.Add(ttl))
|
||||
return cached, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loadMap, _ := value.(map[int64]*AccountLoadInfo)
|
||||
if loadMap == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) fetchAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
baseCtx := context.Background()
|
||||
if ctx != nil {
|
||||
baseCtx = context.WithoutCancel(ctx)
|
||||
}
|
||||
redisCtx, cancel := context.WithTimeout(baseCtx, accountLoadBatchFetchTimeout)
|
||||
defer cancel()
|
||||
return s.cache.GetAccountsLoadBatch(redisCtx, accounts)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) getCachedAccountLoadBatch(key string, now time.Time) (map[int64]*AccountLoadInfo, bool) {
|
||||
s.accountLoadCacheMu.RLock()
|
||||
cached, ok := s.accountLoadCache[key]
|
||||
s.accountLoadCacheMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if !now.Before(cached.expiresAt) {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
if current, exists := s.accountLoadCache[key]; exists && !now.Before(current.expiresAt) {
|
||||
delete(s.accountLoadCache, key)
|
||||
}
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
return cached.loadMap, true
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) storeCachedAccountLoadBatch(key string, loadMap map[int64]*AccountLoadInfo, expiresAt time.Time) {
|
||||
s.accountLoadCacheMu.Lock()
|
||||
if s.accountLoadCache == nil {
|
||||
s.accountLoadCache = make(map[string]cachedAccountLoadBatch)
|
||||
}
|
||||
if len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||
now := time.Now()
|
||||
for cacheKey, cached := range s.accountLoadCache {
|
||||
if !now.Before(cached.expiresAt) {
|
||||
delete(s.accountLoadCache, cacheKey)
|
||||
}
|
||||
}
|
||||
for len(s.accountLoadCache) >= maxAccountLoadBatchCacheEntries {
|
||||
for cacheKey := range s.accountLoadCache {
|
||||
delete(s.accountLoadCache, cacheKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
s.accountLoadCache[key] = cachedAccountLoadBatch{
|
||||
loadMap: loadMap,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
s.accountLoadCacheMu.Unlock()
|
||||
}
|
||||
|
||||
func accountLoadBatchCacheKey(accounts []AccountWithConcurrency) string {
|
||||
hash := sha256.New()
|
||||
var buf [16]byte
|
||||
for _, account := range accounts {
|
||||
binary.LittleEndian.PutUint64(buf[:8], uint64(account.ID))
|
||||
binary.LittleEndian.PutUint64(buf[8:], uint64(int64(account.MaxConcurrency)))
|
||||
_, _ = hash.Write(buf[:])
|
||||
}
|
||||
sum := hash.Sum(nil)
|
||||
return strconv.Itoa(len(accounts)) + ":" + hex.EncodeToString(sum)
|
||||
}
|
||||
|
||||
func cloneAccountLoadMap(loadMap map[int64]*AccountLoadInfo) map[int64]*AccountLoadInfo {
|
||||
if len(loadMap) == 0 {
|
||||
return map[int64]*AccountLoadInfo{}
|
||||
}
|
||||
clone := make(map[int64]*AccountLoadInfo, len(loadMap))
|
||||
for accountID, loadInfo := range loadMap {
|
||||
if loadInfo == nil {
|
||||
clone[accountID] = nil
|
||||
continue
|
||||
}
|
||||
copied := *loadInfo
|
||||
clone[accountID] = &copied
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
// GetUsersLoadBatch returns load info for multiple users.
|
||||
|
||||
@ -7,7 +7,9 @@ import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -32,6 +34,7 @@ type stubConcurrencyCacheForTest struct {
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
releasedRequestIDs []string
|
||||
loadBatchCalls atomic.Int64
|
||||
}
|
||||
|
||||
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
||||
@ -82,6 +85,7 @@ func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ in
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
c.loadBatchCalls.Add(1)
|
||||
return c.loadBatch, c.loadBatchErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
@ -237,6 +241,47 @@ func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_UsesShortTTLCache(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{
|
||||
loadBatch: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||
},
|
||||
}
|
||||
svc := NewConcurrencyService(cache)
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||
|
||||
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||
first, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, first[int64(1)].CurrentConcurrency)
|
||||
|
||||
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||
second, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, second[int64(1)].CurrentConcurrency)
|
||||
require.Equal(t, int64(1), cache.loadBatchCalls.Load())
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatchFresh_BypassesShortTTLCache(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{
|
||||
loadBatch: map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 1, LoadRate: 20},
|
||||
},
|
||||
}
|
||||
svc := NewConcurrencyService(cache)
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Second)
|
||||
|
||||
accounts := []AccountWithConcurrency{{ID: 1, MaxConcurrency: 5}}
|
||||
_, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache.loadBatch[1] = &AccountLoadInfo{AccountID: 1, CurrentConcurrency: 4, LoadRate: 80}
|
||||
fresh, err := svc.GetAccountsLoadBatchFresh(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 4, fresh[int64(1)].CurrentConcurrency)
|
||||
require.Equal(t, int64(2), cache.loadBatchCalls.Load())
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
@ -44,6 +44,10 @@ const (
|
||||
ContentModerationKeywordModeKeywordAndAPI = "keyword_and_api"
|
||||
ContentModerationKeywordModeAPIOnly = "api_only"
|
||||
|
||||
ContentModerationModelFilterAll = "all"
|
||||
ContentModerationModelFilterInclude = "include"
|
||||
ContentModerationModelFilterExclude = "exclude"
|
||||
|
||||
ContentModerationProtocolAnthropicMessages = "anthropic_messages"
|
||||
ContentModerationProtocolOpenAIResponses = "openai_responses"
|
||||
ContentModerationProtocolOpenAIChat = "openai_chat_completions"
|
||||
@ -80,6 +84,8 @@ const (
|
||||
maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024
|
||||
maxContentModerationBlockedKeywords = 10000
|
||||
maxContentModerationBlockedKeywordRunes = 200
|
||||
maxContentModerationModelFilterModels = 1000
|
||||
maxContentModerationModelFilterRunes = 200
|
||||
|
||||
contentModerationCleanupInterval = 24 * time.Hour
|
||||
contentModerationCleanupTimeout = 30 * time.Minute
|
||||
@ -127,32 +133,33 @@ func ContentModerationCategories() []string {
|
||||
}
|
||||
|
||||
type ContentModerationConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Model string `json:"model"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
APIKeys []string `json:"api_keys,omitempty"`
|
||||
TimeoutMS int `json:"timeout_ms"`
|
||||
SampleRate int `json:"sample_rate"`
|
||||
AllGroups bool `json:"all_groups"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
RecordNonHits bool `json:"record_non_hits"`
|
||||
Thresholds map[string]float64 `json:"thresholds"`
|
||||
WorkerCount int `json:"worker_count"`
|
||||
QueueSize int `json:"queue_size"`
|
||||
BlockStatus int `json:"block_status"`
|
||||
BlockMessage string `json:"block_message"`
|
||||
EmailOnHit bool `json:"email_on_hit"`
|
||||
AutoBanEnabled bool `json:"auto_ban_enabled"`
|
||||
BanThreshold int `json:"ban_threshold"`
|
||||
ViolationWindowHours int `json:"violation_window_hours"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
HitRetentionDays int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords []string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode string `json:"keyword_blocking_mode"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Model string `json:"model"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
APIKeys []string `json:"api_keys,omitempty"`
|
||||
TimeoutMS int `json:"timeout_ms"`
|
||||
SampleRate int `json:"sample_rate"`
|
||||
AllGroups bool `json:"all_groups"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
RecordNonHits bool `json:"record_non_hits"`
|
||||
Thresholds map[string]float64 `json:"thresholds"`
|
||||
WorkerCount int `json:"worker_count"`
|
||||
QueueSize int `json:"queue_size"`
|
||||
BlockStatus int `json:"block_status"`
|
||||
BlockMessage string `json:"block_message"`
|
||||
EmailOnHit bool `json:"email_on_hit"`
|
||||
AutoBanEnabled bool `json:"auto_ban_enabled"`
|
||||
BanThreshold int `json:"ban_threshold"`
|
||||
ViolationWindowHours int `json:"violation_window_hours"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
HitRetentionDays int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords []string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode string `json:"keyword_blocking_mode"`
|
||||
ModelFilter ContentModerationModelFilter `json:"model_filter"`
|
||||
}
|
||||
|
||||
type ContentModerationConfigView struct {
|
||||
@ -184,6 +191,7 @@ type ContentModerationConfigView struct {
|
||||
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords []string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode string `json:"keyword_blocking_mode"`
|
||||
ModelFilter ContentModerationModelFilter `json:"model_filter"`
|
||||
}
|
||||
|
||||
type ContentModerationAPIKeyStatus struct {
|
||||
@ -227,34 +235,40 @@ type ContentModerationTestAuditResult struct {
|
||||
}
|
||||
|
||||
type UpdateContentModerationConfigInput struct {
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
APIKeysMode string `json:"api_keys_mode"`
|
||||
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords *[]string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
Mode *string `json:"mode"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
Model *string `json:"model"`
|
||||
APIKey *string `json:"api_key"`
|
||||
APIKeys *[]string `json:"api_keys"`
|
||||
APIKeysMode string `json:"api_keys_mode"`
|
||||
DeleteAPIKeyHashes *[]string `json:"delete_api_key_hashes"`
|
||||
ClearAPIKey bool `json:"clear_api_key"`
|
||||
TimeoutMS *int `json:"timeout_ms"`
|
||||
SampleRate *int `json:"sample_rate"`
|
||||
AllGroups *bool `json:"all_groups"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
RecordNonHits *bool `json:"record_non_hits"`
|
||||
WorkerCount *int `json:"worker_count"`
|
||||
QueueSize *int `json:"queue_size"`
|
||||
BlockStatus *int `json:"block_status"`
|
||||
BlockMessage *string `json:"block_message"`
|
||||
EmailOnHit *bool `json:"email_on_hit"`
|
||||
AutoBanEnabled *bool `json:"auto_ban_enabled"`
|
||||
BanThreshold *int `json:"ban_threshold"`
|
||||
ViolationWindowHours *int `json:"violation_window_hours"`
|
||||
RetryCount *int `json:"retry_count"`
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords *[]string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
|
||||
ModelFilter *ContentModerationModelFilter `json:"model_filter"`
|
||||
}
|
||||
|
||||
type ContentModerationModelFilter struct {
|
||||
Type string `json:"type"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
type ContentModerationCheckInput struct {
|
||||
@ -581,6 +595,9 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat
|
||||
if input.KeywordBlockingMode != nil {
|
||||
cfg.KeywordBlockingMode = strings.TrimSpace(*input.KeywordBlockingMode)
|
||||
}
|
||||
if input.ModelFilter != nil {
|
||||
cfg.ModelFilter = *input.ModelFilter
|
||||
}
|
||||
if input.AllGroups != nil {
|
||||
cfg.AllGroups = *input.AllGroups
|
||||
}
|
||||
@ -719,7 +736,8 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
|
||||
"error", err)
|
||||
return allow, nil
|
||||
}
|
||||
inScope := cfg.includesGroup(input.GroupID)
|
||||
inGroupScope := cfg.includesGroup(input.GroupID)
|
||||
inModelScope := cfg.includesModel(input.Model)
|
||||
slog.Info("content_moderation.config_loaded",
|
||||
"user_id", input.UserID,
|
||||
"api_key_id", input.APIKeyID,
|
||||
@ -733,7 +751,10 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
|
||||
"mode", cfg.Mode,
|
||||
"all_groups", cfg.AllGroups,
|
||||
"configured_group_ids", cfg.GroupIDs,
|
||||
"in_scope", inScope,
|
||||
"in_group_scope", inGroupScope,
|
||||
"model_filter_type", cfg.ModelFilter.Type,
|
||||
"configured_models", cfg.ModelFilter.Models,
|
||||
"in_model_scope", inModelScope,
|
||||
"sample_rate", cfg.SampleRate,
|
||||
"api_key_count", len(cfg.apiKeys()),
|
||||
"pre_hash_check_enabled", cfg.PreHashCheckEnabled,
|
||||
@ -756,7 +777,7 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
|
||||
"protocol", input.Protocol)
|
||||
return allow, nil
|
||||
}
|
||||
if !inScope {
|
||||
if !inGroupScope {
|
||||
slog.Info("content_moderation.skip_group_out_of_scope",
|
||||
"user_id", input.UserID,
|
||||
"api_key_id", input.APIKeyID,
|
||||
@ -768,6 +789,19 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
|
||||
"configured_group_ids", cfg.GroupIDs)
|
||||
return allow, nil
|
||||
}
|
||||
if !inModelScope {
|
||||
slog.Info("content_moderation.skip_model_out_of_scope",
|
||||
"user_id", input.UserID,
|
||||
"api_key_id", input.APIKeyID,
|
||||
"group_id", contentModerationLogGroupID(input.GroupID),
|
||||
"group_name", input.GroupName,
|
||||
"endpoint", input.Endpoint,
|
||||
"protocol", input.Protocol,
|
||||
"model", input.Model,
|
||||
"model_filter_type", cfg.ModelFilter.Type,
|
||||
"configured_models", cfg.ModelFilter.Models)
|
||||
return allow, nil
|
||||
}
|
||||
content := ExtractContentModerationInput(input.Protocol, input.Body)
|
||||
if content.IsEmpty() {
|
||||
slog.Info("content_moderation.skip_empty_input",
|
||||
@ -1025,6 +1059,9 @@ func (s *ContentModerationService) worker(id int) {
|
||||
if !cfg.includesGroup(task.input.GroupID) {
|
||||
return
|
||||
}
|
||||
if !cfg.includesModel(task.input.Model) {
|
||||
return
|
||||
}
|
||||
s.asyncActive.Add(1)
|
||||
defer s.asyncActive.Add(-1)
|
||||
queueDelay := int(time.Since(task.enqueuedAt).Milliseconds())
|
||||
@ -1270,6 +1307,9 @@ func (s *ContentModerationService) validateConfig(ctx context.Context, cfg *Cont
|
||||
if cfg.BlockStatus < 400 || cfg.BlockStatus > 599 {
|
||||
return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_BLOCK_STATUS", "拦截 HTTP 状态码必须在 400-599 之间")
|
||||
}
|
||||
if cfg.ModelFilter.Type != ContentModerationModelFilterAll && len(cfg.ModelFilter.Models) == 0 {
|
||||
return infraerrors.BadRequest("INVALID_CONTENT_MODERATION_MODEL_FILTER", "指定或排除模型时至少需要配置 1 个模型")
|
||||
}
|
||||
if !cfg.AllGroups && len(cfg.GroupIDs) > 0 && s.groupRepo != nil {
|
||||
for _, groupID := range cfg.GroupIDs {
|
||||
if _, err := s.groupRepo.GetByIDLite(ctx, groupID); err != nil {
|
||||
@ -1590,6 +1630,10 @@ func defaultContentModerationConfig() *ContentModerationConfig {
|
||||
PreHashCheckEnabled: false,
|
||||
BlockedKeywords: []string{},
|
||||
KeywordBlockingMode: ContentModerationKeywordModeKeywordAndAPI,
|
||||
ModelFilter: ContentModerationModelFilter{
|
||||
Type: ContentModerationModelFilterAll,
|
||||
Models: []string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -1670,6 +1714,7 @@ func (cfg *ContentModerationConfig) normalize() {
|
||||
cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds)
|
||||
cfg.BlockedKeywords = normalizeBlockedKeywords(cfg.BlockedKeywords)
|
||||
cfg.KeywordBlockingMode = normalizeKeywordBlockingMode(cfg.KeywordBlockingMode)
|
||||
cfg.ModelFilter = normalizeContentModerationModelFilter(cfg.ModelFilter)
|
||||
}
|
||||
|
||||
func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
|
||||
@ -1687,6 +1732,21 @@ func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cfg *ContentModerationConfig) includesModel(model string) bool {
|
||||
if cfg == nil {
|
||||
return true
|
||||
}
|
||||
filter := normalizeContentModerationModelFilter(cfg.ModelFilter)
|
||||
switch filter.Type {
|
||||
case ContentModerationModelFilterInclude:
|
||||
return contentModerationModelListContains(filter.Models, model)
|
||||
case ContentModerationModelFilterExclude:
|
||||
return !contentModerationModelListContains(filter.Models, model)
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func contentModerationLogGroupID(groupID *int64) int64 {
|
||||
if groupID == nil {
|
||||
return 0
|
||||
@ -1848,6 +1908,7 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con
|
||||
PreHashCheckEnabled: cfg.PreHashCheckEnabled,
|
||||
BlockedKeywords: append([]string(nil), cfg.BlockedKeywords...),
|
||||
KeywordBlockingMode: cfg.KeywordBlockingMode,
|
||||
ModelFilter: cloneContentModerationModelFilter(cfg.ModelFilter),
|
||||
}
|
||||
}
|
||||
|
||||
@ -2125,6 +2186,73 @@ func normalizeKeywordBlockingMode(mode string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter {
|
||||
out := ContentModerationModelFilter{
|
||||
Type: normalizeContentModerationModelFilterType(filter.Type),
|
||||
Models: normalizeContentModerationModelNames(filter.Models),
|
||||
}
|
||||
if out.Type == ContentModerationModelFilterAll {
|
||||
out.Models = []string{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneContentModerationModelFilter(filter ContentModerationModelFilter) ContentModerationModelFilter {
|
||||
normalized := normalizeContentModerationModelFilter(filter)
|
||||
normalized.Models = append([]string(nil), normalized.Models...)
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeContentModerationModelFilterType(filterType string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(filterType)) {
|
||||
case ContentModerationModelFilterInclude:
|
||||
return ContentModerationModelFilterInclude
|
||||
case ContentModerationModelFilterExclude:
|
||||
return ContentModerationModelFilterExclude
|
||||
case ContentModerationModelFilterAll:
|
||||
return ContentModerationModelFilterAll
|
||||
default:
|
||||
return ContentModerationModelFilterAll
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeContentModerationModelNames(models []string) []string {
|
||||
if len(models) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
out := make([]string, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, raw := range models {
|
||||
model := trimRunes(strings.TrimSpace(raw), maxContentModerationModelFilterRunes)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(model)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, model)
|
||||
if len(out) >= maxContentModerationModelFilterModels {
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func contentModerationModelListContains(models []string, model string) bool {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
for _, candidate := range models {
|
||||
if strings.ToLower(strings.TrimSpace(candidate)) == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchBlockedKeyword(text string, keywords []string) (string, bool) {
|
||||
if text == "" || len(keywords) == 0 {
|
||||
return "", false
|
||||
|
||||
@ -48,44 +48,44 @@ func collectLastRoleMessage(messages gjson.Result, role string, parts *[]string,
|
||||
if !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == role {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
array := messages.Array()
|
||||
if len(array) == 0 {
|
||||
return
|
||||
}
|
||||
last := array[len(array)-1]
|
||||
if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != role {
|
||||
return
|
||||
}
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectContentValue(last.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
|
||||
return
|
||||
}
|
||||
*parts = append(*parts, candidate...)
|
||||
*images = append(*images, candidateImages...)
|
||||
}
|
||||
|
||||
func collectLastAnthropicUserMessage(messages gjson.Result, parts *[]string, images *[]string) {
|
||||
if !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
if strings.ToLower(strings.TrimSpace(msg.Get("role").String())) == "user" {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectAnthropicUserContentValue(msg.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
array := messages.Array()
|
||||
if len(array) == 0 {
|
||||
return
|
||||
}
|
||||
last := array[len(array)-1]
|
||||
if strings.ToLower(strings.TrimSpace(last.Get("role").String())) != "user" {
|
||||
return
|
||||
}
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
collectAnthropicUserContentValue(last.Get("content"), &candidate, &candidateImages)
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
|
||||
return
|
||||
}
|
||||
*parts = append(*parts, candidate...)
|
||||
*images = append(*images, candidateImages...)
|
||||
}
|
||||
|
||||
func collectAnthropicUserContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||
@ -128,18 +128,17 @@ func collectLastResponsesInput(input gjson.Result, parts *[]string, images *[]st
|
||||
case input.Type == gjson.String:
|
||||
addModerationText(parts, input.String())
|
||||
case input.IsArray():
|
||||
var last gjson.Result
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
if isResponsesUserTextItem(item) {
|
||||
last = item
|
||||
}
|
||||
return true
|
||||
})
|
||||
if last.Exists() {
|
||||
collectContentValue(last.Get("content"), parts, images)
|
||||
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
|
||||
collectContentValue(last, parts, images)
|
||||
}
|
||||
array := input.Array()
|
||||
if len(array) == 0 {
|
||||
return
|
||||
}
|
||||
last := array[len(array)-1]
|
||||
if !isResponsesUserTextItem(last) {
|
||||
return
|
||||
}
|
||||
collectContentValue(last.Get("content"), parts, images)
|
||||
if last.Get("type").String() == "input_text" || last.Get("text").Exists() {
|
||||
collectContentValue(last, parts, images)
|
||||
}
|
||||
case input.IsObject():
|
||||
if isResponsesUserTextItem(input) {
|
||||
@ -176,29 +175,29 @@ func collectLastGeminiContent(contents gjson.Result, parts *[]string, images *[]
|
||||
if !contents.IsArray() {
|
||||
return
|
||||
}
|
||||
var lastParts []string
|
||||
var lastImages []string
|
||||
contents.ForEach(func(_, content gjson.Result) bool {
|
||||
role := strings.ToLower(strings.TrimSpace(content.Get("role").String()))
|
||||
if role == "" || role == "user" {
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
if arr := content.Get("parts"); arr.IsArray() {
|
||||
arr.ForEach(func(_, part gjson.Result) bool {
|
||||
addModerationText(&candidate, part.Get("text").String())
|
||||
addGeminiModerationImage(&candidateImages, part)
|
||||
return true
|
||||
})
|
||||
}
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) != "" || len(candidateImages) > 0 {
|
||||
lastParts = candidate
|
||||
lastImages = candidateImages
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
*parts = append(*parts, lastParts...)
|
||||
*images = append(*images, lastImages...)
|
||||
array := contents.Array()
|
||||
if len(array) == 0 {
|
||||
return
|
||||
}
|
||||
last := array[len(array)-1]
|
||||
role := strings.ToLower(strings.TrimSpace(last.Get("role").String()))
|
||||
if role != "" && role != "user" {
|
||||
return
|
||||
}
|
||||
var candidate []string
|
||||
var candidateImages []string
|
||||
if arr := last.Get("parts"); arr.IsArray() {
|
||||
arr.ForEach(func(_, part gjson.Result) bool {
|
||||
addModerationText(&candidate, part.Get("text").String())
|
||||
addGeminiModerationImage(&candidateImages, part)
|
||||
return true
|
||||
})
|
||||
}
|
||||
if normalizeContentModerationText(strings.Join(candidate, "\n")) == "" && len(candidateImages) == 0 {
|
||||
return
|
||||
}
|
||||
*parts = append(*parts, candidate...)
|
||||
*images = append(*images, candidateImages...)
|
||||
}
|
||||
|
||||
func collectContentValue(value gjson.Result, parts *[]string, images *[]string) {
|
||||
|
||||
179
backend/internal/service/content_moderation_input_test.go
Normal file
179
backend/internal/service/content_moderation_input_test.go
Normal file
@ -0,0 +1,179 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 当数组末尾不是用户消息时(典型场景:Agent 工具循环结束于 tool/assistant),
|
||||
// 应直接跳过审计——不再回溯查找历史中的某条用户消息。
|
||||
|
||||
func TestExtractContentModerationInput_AnthropicAgentToolLoopSkipsAudit(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"user","content":"调用一下天气工具"},
|
||||
{"role":"assistant","content":[{"type":"tool_use","id":"tool_1","name":"weather","input":{}}]},
|
||||
{"role":"user","content":[{"type":"tool_result","tool_use_id":"tool_1","content":"晴 25 度"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||||
|
||||
require.Empty(t, input.Text)
|
||||
require.Empty(t, input.Images)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_AnthropicFirstTurnExtractsUser(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"user","content":"Q1"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||||
|
||||
require.Equal(t, "Q1", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_AnthropicMultiTurnExtractsLatestUser(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"user","content":"Q1"},
|
||||
{"role":"assistant","content":"A1"},
|
||||
{"role":"user","content":"Q2"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||||
|
||||
require.Equal(t, "Q2", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_AnthropicStreamResendExtractsResend(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"user","content":"原问题"},
|
||||
{"role":"assistant","content":"部分回答……"},
|
||||
{"role":"user","content":"重发"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolAnthropicMessages, body)
|
||||
|
||||
require.Equal(t, "重发", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_OpenAIChatAgentToolLoopSkipsAudit(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"system","content":"sys"},
|
||||
{"role":"user","content":"列出我的订单"},
|
||||
{"role":"assistant","content":null,"tool_calls":[{"id":"call_1","type":"function","function":{"name":"orders","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","content":"[]"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
|
||||
|
||||
require.Empty(t, input.Text)
|
||||
require.Empty(t, input.Images)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_OpenAIChatMultiTurnExtractsLatestUser(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages": [
|
||||
{"role":"user","content":"Q1"},
|
||||
{"role":"assistant","content":"A1"},
|
||||
{"role":"user","content":"Q2"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIChat, body)
|
||||
|
||||
require.Equal(t, "Q2", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_GeminiAgentToolLoopSkipsAudit(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"contents": [
|
||||
{"role":"user","parts":[{"text":"查询天气"}]},
|
||||
{"role":"model","parts":[{"functionCall":{"name":"weather","args":{}}}]},
|
||||
{"role":"user","parts":[{"functionResponse":{"name":"weather","response":{"temp":25}}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
|
||||
|
||||
require.Empty(t, input.Text)
|
||||
require.Empty(t, input.Images)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_GeminiFirstTurnExtractsUser(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"contents": [
|
||||
{"role":"user","parts":[{"text":"你好"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
|
||||
|
||||
require.Equal(t, "你好", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_GeminiMultiTurnExtractsLatestUser(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"contents": [
|
||||
{"role":"user","parts":[{"text":"Q1"}]},
|
||||
{"role":"model","parts":[{"text":"A1"}]},
|
||||
{"role":"user","parts":[{"text":"Q2"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolGemini, body)
|
||||
|
||||
require.Equal(t, "Q2", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_ResponsesAgentToolLoopSkipsAudit(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"input":[
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"运行测试"}]},
|
||||
{"type":"function_call","call_id":"call_1","name":"run_tests","arguments":"{}"},
|
||||
{"type":"function_call_output","call_id":"call_1","output":"all passed"}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
|
||||
|
||||
require.Empty(t, input.Text)
|
||||
require.Empty(t, input.Images)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_ResponsesLastUserMessageExtracted(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"input":[
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"first"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"answer"}]},
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"latest"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
|
||||
|
||||
require.Equal(t, "latest", input.Text)
|
||||
}
|
||||
|
||||
func TestExtractContentModerationInput_ResponsesLastIsAssistantSkipped(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"input":[
|
||||
{"type":"message","role":"user","content":[{"type":"input_text","text":"q1"}]},
|
||||
{"type":"message","role":"assistant","content":[{"type":"output_text","text":"a1"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
input := ExtractContentModerationInput(ContentModerationProtocolOpenAIResponses, body)
|
||||
|
||||
require.Empty(t, input.Text)
|
||||
require.Empty(t, input.Images)
|
||||
}
|
||||
@ -530,6 +530,147 @@ func TestNormalizeKeywordBlockingMode_UnknownFallsBackToDefault(t *testing.T) {
|
||||
require.Equal(t, ContentModerationKeywordModeAPIOnly, normalizeKeywordBlockingMode("api_only"))
|
||||
}
|
||||
|
||||
func TestContentModerationCheck_ModelFilterAllAuditsEveryModel(t *testing.T) {
|
||||
cfg := defaultContentModerationModelFilterTestConfig()
|
||||
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterAll}
|
||||
svc, repo := newContentModerationModelFilterTestService(t, cfg)
|
||||
|
||||
for _, model := range []string{"gpt-5.5", "gpt-5.4"} {
|
||||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: model,
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
|
||||
}
|
||||
require.Len(t, repo.logs, 2)
|
||||
}
|
||||
|
||||
func TestContentModerationCheck_ModelFilterIncludeOnlyAuditsListedModels(t *testing.T) {
|
||||
cfg := defaultContentModerationModelFilterTestConfig()
|
||||
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}}
|
||||
svc, repo := newContentModerationModelFilterTestService(t, cfg)
|
||||
|
||||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: "gpt-5.5",
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
|
||||
|
||||
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: "gpt-5.4",
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Allowed)
|
||||
require.False(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionAllow, decision.Action)
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
|
||||
}
|
||||
|
||||
func TestContentModerationCheck_ModelFilterExcludeSkipsListedModels(t *testing.T) {
|
||||
cfg := defaultContentModerationModelFilterTestConfig()
|
||||
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterExclude, Models: []string{"gpt-5.4"}}
|
||||
svc, repo := newContentModerationModelFilterTestService(t, cfg)
|
||||
|
||||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: "gpt-5.5",
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
|
||||
|
||||
decision, err = svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: "gpt-5.4",
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Allowed)
|
||||
require.False(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionAllow, decision.Action)
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
|
||||
}
|
||||
|
||||
func TestContentModerationLoadConfig_LegacyConfigDefaultsModelFilterToAll(t *testing.T) {
|
||||
raw := `{"enabled":true,"mode":"pre_block","base_url":"https://api.openai.com","model":"omni-moderation-latest","blocked_keywords":["secret-token"]}`
|
||||
svc := NewContentModerationService(
|
||||
&contentModerationTestSettingRepo{values: map[string]string{
|
||||
SettingKeyContentModerationConfig: raw,
|
||||
}},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
cfg, err := svc.loadConfig(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ContentModerationModelFilterAll, cfg.ModelFilter.Type)
|
||||
require.Empty(t, cfg.ModelFilter.Models)
|
||||
require.True(t, cfg.includesModel("gpt-5.5"))
|
||||
require.True(t, cfg.includesModel("gpt-5.4"))
|
||||
}
|
||||
|
||||
func TestContentModerationCheck_ModelFilterUsesRequestedModelNotBodyModel(t *testing.T) {
|
||||
cfg := defaultContentModerationModelFilterTestConfig()
|
||||
cfg.ModelFilter = ContentModerationModelFilter{Type: ContentModerationModelFilterInclude, Models: []string{"gpt-5.5"}}
|
||||
svc, repo := newContentModerationModelFilterTestService(t, cfg)
|
||||
|
||||
decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
|
||||
Model: "gpt-5.5",
|
||||
Protocol: ContentModerationProtocolOpenAIChat,
|
||||
Body: []byte(`{"model":"mapped-upstream-model","messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, decision.Blocked)
|
||||
require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
|
||||
require.Len(t, repo.logs, 1)
|
||||
require.Equal(t, "gpt-5.5", repo.logs[0].Model)
|
||||
}
|
||||
|
||||
func defaultContentModerationModelFilterTestConfig() *ContentModerationConfig {
|
||||
cfg := defaultContentModerationConfig()
|
||||
cfg.Enabled = true
|
||||
cfg.Mode = ContentModerationModePreBlock
|
||||
cfg.BlockedKeywords = []string{"secret-token"}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func newContentModerationModelFilterTestService(t *testing.T, cfg *ContentModerationConfig) (*ContentModerationService, *contentModerationTestRepo) {
|
||||
t.Helper()
|
||||
rawCfg, err := json.Marshal(cfg)
|
||||
require.NoError(t, err)
|
||||
repo := &contentModerationTestRepo{}
|
||||
svc := NewContentModerationService(
|
||||
&contentModerationTestSettingRepo{values: map[string]string{
|
||||
SettingKeyRiskControlEnabled: "true",
|
||||
SettingKeyContentModerationConfig: string(rawCfg),
|
||||
}},
|
||||
repo,
|
||||
&contentModerationTestHashCache{},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
return svc, repo
|
||||
}
|
||||
|
||||
func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) {
|
||||
cfg := defaultContentModerationConfig()
|
||||
cfg.APIKeys = []string{"sk-old-a", "sk-old-b"}
|
||||
|
||||
@ -132,6 +132,9 @@ const (
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// API Key IP 访问控制设置
|
||||
SettingKeyAPIKeyACLTrustForwardedIP = "api_key_acl_trust_forwarded_ip" // API Key IP 白/黑名单是否信任转发 IP
|
||||
|
||||
// TOTP 双因素认证设置
|
||||
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
|
||||
|
||||
@ -406,12 +409,15 @@ const (
|
||||
// 用于避免 Cloudflare 对浏览器型 UA 的质询拦截。
|
||||
SettingKeyOpenAICodexUserAgent = "openai_codex_user_agent"
|
||||
|
||||
// Balance Low Notification
|
||||
// 余额不足提醒
|
||||
SettingKeyBalanceLowNotifyEnabled = "balance_low_notify_enabled" // 全局开关
|
||||
SettingKeyBalanceLowNotifyThreshold = "balance_low_notify_threshold" // 默认阈值(USD)
|
||||
SettingKeyBalanceLowNotifyRechargeURL = "balance_low_notify_recharge_url" // 充值页面 URL
|
||||
|
||||
// Account Quota Notification
|
||||
// 订阅到期提醒
|
||||
SettingKeySubscriptionExpiryNotifyEnabled = "subscription_expiry_notify_enabled" // 订阅到期提醒全局开关,默认开启
|
||||
|
||||
// 账号限额通知
|
||||
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
|
||||
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组)
|
||||
|
||||
|
||||
@ -5739,6 +5739,31 @@ func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header,
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyBedrockCCCompat 应用 Bedrock CC 兼容转换(渠道级模型映射后调用)
|
||||
// 清理 Anthropic API 专有字段、注入 Bedrock 必需字段、修复 thinking/tool_use ID
|
||||
func (s *GatewayService) ApplyBedrockCCCompat(ctx context.Context, body []byte, model string, account *Account, groupID *int64) []byte {
|
||||
if !s.isBedrockCCCompatEnabled(ctx, account, groupID) {
|
||||
return body
|
||||
}
|
||||
body = sanitizeBedrockCCFields(body)
|
||||
body = sanitizeBedrockThinking(body, model)
|
||||
body = sanitizeBedrockToolUseIDs(body)
|
||||
body = sanitizeBedrockCCBetaTokens(body, model)
|
||||
return body
|
||||
}
|
||||
|
||||
// isBedrockCCCompatEnabled 检查渠道是否启用了 Bedrock CC 兼容模式
|
||||
func (s *GatewayService) isBedrockCCCompatEnabled(ctx context.Context, account *Account, groupID *int64) bool {
|
||||
if groupID == nil || s.channelService == nil {
|
||||
return false
|
||||
}
|
||||
ch, err := s.channelService.GetChannelForGroup(ctx, *groupID)
|
||||
if err != nil || ch == nil {
|
||||
return false
|
||||
}
|
||||
return ch.IsBedrockCCCompatEnabled(account.Platform)
|
||||
}
|
||||
|
||||
// forwardBedrock 转发请求到 AWS Bedrock
|
||||
func (s *GatewayService) forwardBedrock(
|
||||
ctx context.Context,
|
||||
@ -5771,7 +5796,7 @@ func (s *GatewayService) forwardBedrock(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens)
|
||||
bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("prepare bedrock request body: %w", err)
|
||||
}
|
||||
|
||||
@ -126,17 +126,17 @@ func TestResolveBedrockBetaTokensForRequest_FiltersAfterBedrockTransform(t *test
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking 验证:
|
||||
// 管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||
// 但请求体包含 thinking 字段 → 自动注入后应被 block。
|
||||
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *testing.T) {
|
||||
// TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse 验证:
|
||||
// 管理员 block 了 computer-use,客户端不在 header 中带该 token,
|
||||
// 但请求体包含 computer_use 工具 → 自动注入后应被 block。
|
||||
func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedComputerUse(t *testing.T) {
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "interleaved-thinking-2025-05-14",
|
||||
BetaToken: "computer-use-2025-11-24",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "thinking is blocked",
|
||||
ErrorMessage: "computer use is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -155,18 +155,18 @@ func TestResolveBedrockBetaTokensForRequest_BlocksBodyAutoInjectedThinking(t *te
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
// header 中不带 beta token,但 body 中有 thinking 字段
|
||||
// header 中不带 beta token,但 body 中有 computer_use 工具
|
||||
_, err = svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"", // 空 header
|
||||
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||
[]byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("expected body-injected interleaved-thinking to be blocked")
|
||||
t.Fatal("expected body-injected computer-use to be blocked")
|
||||
}
|
||||
if err.Error() != "thinking is blocked" {
|
||||
if err.Error() != "computer use is blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@ -222,10 +222,10 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
|
||||
settings := &BetaPolicySettings{
|
||||
Rules: []BetaPolicyRule{
|
||||
{
|
||||
BetaToken: "computer-use-2025-11-24",
|
||||
BetaToken: "context-1m-2025-08-07",
|
||||
Action: BetaPolicyActionBlock,
|
||||
Scope: BetaPolicyScopeAll,
|
||||
ErrorMessage: "computer use is blocked",
|
||||
ErrorMessage: "context is blocked",
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -244,12 +244,12 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
|
||||
}
|
||||
account := &Account{Platform: PlatformAnthropic, Type: AccountTypeBedrock}
|
||||
|
||||
// body 中有 thinking(会注入 interleaved-thinking),但 block 规则只针对 computer-use
|
||||
// body 中有 computer_use 工具(会注入 computer-use token),但 block 规则只针对 context-1m
|
||||
tokens, err := svc.resolveBedrockBetaTokensForRequest(
|
||||
context.Background(),
|
||||
account,
|
||||
"",
|
||||
[]byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`),
|
||||
[]byte(`{"tools":[{"type":"computer_20250124","name":"computer"}],"messages":[{"role":"user","content":"hi"}]}`),
|
||||
"us.anthropic.claude-opus-4-6-v1",
|
||||
)
|
||||
if err != nil {
|
||||
@ -257,11 +257,11 @@ func TestResolveBedrockBetaTokensForRequest_PassesWhenNoBlockRuleMatches(t *test
|
||||
}
|
||||
found := false
|
||||
for _, token := range tokens {
|
||||
if token == "interleaved-thinking-2025-05-14" {
|
||||
if token == "computer-use-2025-11-24" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("expected interleaved-thinking token to be present")
|
||||
t.Fatal("expected computer-use token to be present")
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,169 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAIAccountStateUpdateTimeout = 5 * time.Second
|
||||
openAIOAuth429FallbackCooldown = 5 * time.Second
|
||||
openAIStopSchedulingBridgeCooldown = 2 * time.Minute
|
||||
openAIOAuth429StormWindow = 10 * time.Second
|
||||
openAIOAuth429StormThreshold = 20
|
||||
openAIOAuth429StormMaxAccountSwitches = 1
|
||||
)
|
||||
|
||||
func openAIAccountStateContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
base := context.Background()
|
||||
if ctx != nil {
|
||||
base = context.WithoutCancel(ctx)
|
||||
}
|
||||
return context.WithTimeout(base, openAIAccountStateUpdateTimeout)
|
||||
}
|
||||
|
||||
func isOpenAIOAuthAccount(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
func isOpenAIAccount(account *Account) bool {
|
||||
return account != nil && account.Platform == PlatformOpenAI
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleOpenAIAccountUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) bool {
|
||||
stateCtx, cancel := openAIAccountStateContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
s.markOpenAIOAuth429RateLimited(stateCtx, account, headers, responseBody)
|
||||
}
|
||||
if s == nil || account == nil || s.rateLimitService == nil {
|
||||
return false
|
||||
}
|
||||
shouldDisable := s.rateLimitService.HandleUpstreamError(stateCtx, account, statusCode, headers, responseBody)
|
||||
if shouldDisable {
|
||||
s.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) markOpenAIOAuth429RateLimited(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
if s == nil || !isOpenAIOAuthAccount(account) {
|
||||
return
|
||||
}
|
||||
s.recordOpenAIOAuth429()
|
||||
|
||||
cooldownUntil := time.Now().Add(openAIOAuth429FallbackCooldown)
|
||||
if s.rateLimitService != nil {
|
||||
if resetAt := s.rateLimitService.calculateOpenAI429ResetTime(headers); resetAt != nil && resetAt.After(time.Now()) {
|
||||
cooldownUntil = *resetAt
|
||||
} else if resetUnix := parseOpenAIRateLimitResetTime(responseBody); resetUnix != nil {
|
||||
if resetAt := time.Unix(*resetUnix, 0); resetAt.After(time.Now()) {
|
||||
cooldownUntil = resetAt
|
||||
}
|
||||
} else if cooldown, ok := s.rateLimitService.get429FallbackCooldown(ctx, account); ok && cooldown > 0 {
|
||||
cooldownUntil = time.Now().Add(cooldown)
|
||||
}
|
||||
}
|
||||
s.BlockAccountScheduling(account, cooldownUntil, "429")
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||
if s == nil || !isOpenAIAccount(account) {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
blockUntil := until
|
||||
if blockUntil.IsZero() || !blockUntil.After(now) {
|
||||
blockUntil = now.Add(openAIStopSchedulingBridgeCooldown)
|
||||
}
|
||||
|
||||
for {
|
||||
current, loaded := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
if !loaded {
|
||||
actual, stored := s.openaiAccountRuntimeBlockUntil.LoadOrStore(account.ID, blockUntil)
|
||||
if !stored {
|
||||
return
|
||||
}
|
||||
current = actual
|
||||
}
|
||||
|
||||
currentUntil, ok := current.(time.Time)
|
||||
if !ok || currentUntil.IsZero() {
|
||||
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if currentUntil.After(blockUntil) {
|
||||
return
|
||||
}
|
||||
if s.openaiAccountRuntimeBlockUntil.CompareAndSwap(account.ID, current, blockUntil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ClearAccountSchedulingBlock(accountID int64) {
|
||||
if s == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(accountID)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIAccountRuntimeBlocked(account *Account) bool {
|
||||
if s == nil || !isOpenAIAccount(account) {
|
||||
return false
|
||||
}
|
||||
value, ok := s.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
cooldownUntil, ok := value.(time.Time)
|
||||
if !ok || cooldownUntil.IsZero() {
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||
return false
|
||||
}
|
||||
if time.Now().Before(cooldownUntil) {
|
||||
return true
|
||||
}
|
||||
s.openaiAccountRuntimeBlockUntil.Delete(account.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) recordOpenAIOAuth429() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||
if windowStart == 0 || now.Sub(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||
if s.openaiOAuth429WindowStartUnixNano.CompareAndSwap(windowStart, now.UnixNano()) {
|
||||
s.openaiOAuth429WindowCount.Store(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.openaiOAuth429WindowCount.Add(1)
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) isOpenAIOAuth429Storm() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
windowStart := s.openaiOAuth429WindowStartUnixNano.Load()
|
||||
if windowStart == 0 || time.Since(time.Unix(0, windowStart)) >= openAIOAuth429StormWindow {
|
||||
return false
|
||||
}
|
||||
return s.openaiOAuth429WindowCount.Load() >= openAIOAuth429StormThreshold
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) ShouldStopOpenAIOAuth429Failover(account *Account, statusCode int, failedSwitches int) bool {
|
||||
if statusCode != http.StatusTooManyRequests || failedSwitches < openAIOAuth429StormMaxAccountSwitches {
|
||||
return false
|
||||
}
|
||||
if !isOpenAIOAuthAccount(account) {
|
||||
return false
|
||||
}
|
||||
return s.isOpenAIOAuth429Storm()
|
||||
}
|
||||
@ -0,0 +1,101 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpenAI429FastPath_MarksOAuthAccountCoolingDown(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
shouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, nil)
|
||||
apiKeyShouldDisable := svc.handleOpenAIAccountUpstreamError(context.Background(), apiKeyAccount, http.StatusTooManyRequests, http.Header{}, nil)
|
||||
|
||||
require.False(t, shouldDisable)
|
||||
require.False(t, apiKeyShouldDisable)
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(apiKeyAccount))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_AppliesToOpenAIAPIKeyWhenRateLimitServiceStopsScheduling(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 44, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_DoesNotApplyToOtherPlatforms(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "custom_error_code")
|
||||
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlocker_IgnoresNonOpenAIFromRateLimitService(t *testing.T) {
|
||||
gateway := &OpenAIGatewayService{}
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
rateLimitService := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
rateLimitService.SetAccountRuntimeBlocker(gateway)
|
||||
account := &Account{ID: 45, Platform: PlatformGemini, Type: AccountTypeOAuth}
|
||||
|
||||
shouldDisable := rateLimitService.HandleUpstreamError(context.Background(), account, http.StatusForbidden, http.Header{}, []byte("forbidden"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.False(t, gateway.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_DoesNotShortenExistingBlock(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 46, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
longUntil := time.Now().Add(10 * time.Minute)
|
||||
|
||||
svc.BlockAccountScheduling(account, longUntil, "oauth_401")
|
||||
svc.BlockAccountScheduling(account, time.Time{}, "upstream_disable")
|
||||
|
||||
value, ok := svc.openaiAccountRuntimeBlockUntil.Load(account.ID)
|
||||
require.True(t, ok)
|
||||
actualUntil, ok := value.(time.Time)
|
||||
require.True(t, ok)
|
||||
require.WithinDuration(t, longUntil, actualUntil, time.Second)
|
||||
}
|
||||
|
||||
func TestOpenAIRuntimeBlock_ClearAccountSchedulingBlock(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 47, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
|
||||
svc.BlockAccountScheduling(account, time.Now().Add(time.Minute), "429")
|
||||
require.True(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
|
||||
svc.ClearAccountSchedulingBlock(account.ID)
|
||||
require.False(t, svc.isOpenAIAccountRuntimeBlocked(account))
|
||||
}
|
||||
|
||||
func TestShouldStopOpenAIOAuth429Failover_OnlyDuringStorm(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
account := &Account{ID: 42, Platform: PlatformOpenAI, Type: AccountTypeOAuth}
|
||||
apiKeyAccount := &Account{ID: 43, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
|
||||
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||
|
||||
for i := 0; i < openAIOAuth429StormThreshold; i++ {
|
||||
svc.recordOpenAIOAuth429()
|
||||
}
|
||||
|
||||
require.True(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(apiKeyAccount, http.StatusTooManyRequests, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusInternalServerError, 1))
|
||||
require.False(t, svc.ShouldStopOpenAIOAuth429Failover(account, http.StatusTooManyRequests, 0))
|
||||
}
|
||||
@ -92,6 +92,16 @@ type openAIAccountSchedulerMetrics struct {
|
||||
loadSkewMilliTotal atomic.Int64
|
||||
}
|
||||
|
||||
type openAIAccountLoadPlan struct {
|
||||
allCandidates []openAIAccountCandidateScore
|
||||
candidates []openAIAccountCandidateScore
|
||||
staleSnapshotCompactRetry []openAIAccountCandidateScore
|
||||
selectionOrder []openAIAccountCandidateScore
|
||||
candidateCount int
|
||||
topK int
|
||||
loadSkew float64
|
||||
}
|
||||
|
||||
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
|
||||
if m == nil {
|
||||
return
|
||||
@ -360,7 +370,7 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
|
||||
}
|
||||
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if acquireErr == nil && result.Acquired {
|
||||
if acquireErr == nil && result != nil && result.Acquired {
|
||||
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@ -586,6 +596,234 @@ func buildOpenAIWeightedSelectionOrder(
|
||||
return order
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) buildOpenAIAccountLoadPlan(
|
||||
req OpenAIAccountScheduleRequest,
|
||||
filtered []*Account,
|
||||
loadMap map[int64]*AccountLoadInfo,
|
||||
) openAIAccountLoadPlan {
|
||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
errorRate, ttft, hasTTFT := 0.0, 0.0, false
|
||||
if s.stats != nil {
|
||||
errorRate, ttft, hasTTFT = s.stats.snapshot(account.ID)
|
||||
}
|
||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
}
|
||||
|
||||
candidates := allCandidates
|
||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
for _, candidate := range allCandidates {
|
||||
if openAICompactSupportTier(candidate.account) == 0 {
|
||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
}
|
||||
|
||||
plan := openAIAccountLoadPlan{
|
||||
allCandidates: allCandidates,
|
||||
candidates: candidates,
|
||||
staleSnapshotCompactRetry: staleSnapshotCompactRetry,
|
||||
candidateCount: len(candidates),
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
for _, candidate := range candidates {
|
||||
if candidate.account.Priority < minPriority {
|
||||
minPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.account.Priority > maxPriority {
|
||||
maxPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = candidate.loadInfo.WaitingCount
|
||||
}
|
||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if candidate.ttft < minTTFT {
|
||||
minTTFT = candidate.ttft
|
||||
}
|
||||
if candidate.ttft > maxTTFT {
|
||||
maxTTFT = candidate.ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
}
|
||||
plan.loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
|
||||
quotaFactor := item.account.GetQuotaRemainingFraction()
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor +
|
||||
weights.Quota*quotaFactor
|
||||
}
|
||||
plan.candidates = candidates
|
||||
|
||||
plan.topK = s.service.openAIWSLBTopK()
|
||||
if plan.topK > len(candidates) {
|
||||
plan.topK = len(candidates)
|
||||
}
|
||||
if plan.topK <= 0 {
|
||||
plan.topK = 1
|
||||
}
|
||||
|
||||
plan.selectionOrder = s.buildOpenAISelectionOrder(req, plan)
|
||||
return plan
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) buildOpenAISelectionOrder(
|
||||
req OpenAIAccountScheduleRequest,
|
||||
plan openAIAccountLoadPlan,
|
||||
) []openAIAccountCandidateScore {
|
||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 || plan.topK <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupTopK := plan.topK
|
||||
if groupTopK > len(pool) {
|
||||
groupTopK = len(pool)
|
||||
}
|
||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||
}
|
||||
|
||||
if req.RequireCompact {
|
||||
supported := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||
unknown := make([]openAIAccountCandidateScore, 0, len(plan.candidates))
|
||||
for _, candidate := range plan.candidates {
|
||||
switch openAICompactSupportTier(candidate.account) {
|
||||
case 2:
|
||||
supported = append(supported, candidate)
|
||||
case 1:
|
||||
unknown = append(unknown, candidate)
|
||||
}
|
||||
}
|
||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(plan.allCandidates))
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||
if len(plan.staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||
selectionOrder = append(selectionOrder, sortOpenAICompactRetryCandidates(plan.staleSnapshotCompactRetry)...)
|
||||
}
|
||||
return selectionOrder
|
||||
}
|
||||
|
||||
return buildSelectionOrder(plan.candidates)
|
||||
}
|
||||
|
||||
func sortOpenAICompactRetryCandidates(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 {
|
||||
return nil
|
||||
}
|
||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
a, b := ordered[i], ordered[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) tryAcquireOpenAISelectionOrder(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
selectionOrder []openAIAccountCandidateScore,
|
||||
) (*AccountSelectionResult, bool, error) {
|
||||
compactBlocked := false
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, compactBlocked, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, compactBlocked, nil
|
||||
}
|
||||
}
|
||||
return nil, compactBlocked, nil
|
||||
}
|
||||
|
||||
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
ctx context.Context,
|
||||
req OpenAIAccountScheduleRequest,
|
||||
@ -616,8 +854,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
if !account.IsSchedulable() || !account.IsOpenAI() {
|
||||
continue
|
||||
}
|
||||
if s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !account.IsPrivacySet() {
|
||||
s.service.BlockAccountScheduling(account, time.Time{}, "privacy_not_set")
|
||||
_ = s.service.accountRepo.SetError(ctx, account.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
@ -645,214 +887,47 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
|
||||
}
|
||||
}
|
||||
|
||||
allCandidates := make([]openAIAccountCandidateScore, 0, len(filtered))
|
||||
for _, account := range filtered {
|
||||
loadInfo := loadMap[account.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: account.ID}
|
||||
}
|
||||
errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
|
||||
allCandidates = append(allCandidates, openAIAccountCandidateScore{
|
||||
account: account,
|
||||
loadInfo: loadInfo,
|
||||
errorRate: errorRate,
|
||||
ttft: ttft,
|
||||
hasTTFT: hasTTFT,
|
||||
})
|
||||
plan := s.buildOpenAIAccountLoadPlan(req, filtered, loadMap)
|
||||
candidateCount := plan.candidateCount
|
||||
topK := plan.topK
|
||||
loadSkew := plan.loadSkew
|
||||
selectionOrder := plan.selectionOrder
|
||||
if req.RequireCompact && len(plan.candidates) == 0 && len(plan.staleSnapshotCompactRetry) == 0 {
|
||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
|
||||
// Compact 模式下把明确不支持 compact 的账号拆出,仅在 schedulerSnapshot 启用
|
||||
// 时作为最后兜底(snapshot 可能已陈旧)。
|
||||
candidates := allCandidates
|
||||
staleSnapshotCompactRetry := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
candidates = make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
for _, candidate := range allCandidates {
|
||||
if openAICompactSupportTier(candidate.account) == 0 {
|
||||
staleSnapshotCompactRetry = append(staleSnapshotCompactRetry, candidate)
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
if len(candidates) == 0 && len(staleSnapshotCompactRetry) == 0 {
|
||||
return nil, 0, 0, 0, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
}
|
||||
|
||||
candidateCount := len(candidates)
|
||||
loadSkew := 0.0
|
||||
if len(candidates) > 0 {
|
||||
minPriority, maxPriority := candidates[0].account.Priority, candidates[0].account.Priority
|
||||
maxWaiting := 1
|
||||
loadRateSum := 0.0
|
||||
loadRateSumSquares := 0.0
|
||||
minTTFT, maxTTFT := 0.0, 0.0
|
||||
hasTTFTSample := false
|
||||
for _, candidate := range candidates {
|
||||
if candidate.account.Priority < minPriority {
|
||||
minPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.account.Priority > maxPriority {
|
||||
maxPriority = candidate.account.Priority
|
||||
}
|
||||
if candidate.loadInfo.WaitingCount > maxWaiting {
|
||||
maxWaiting = candidate.loadInfo.WaitingCount
|
||||
}
|
||||
if candidate.hasTTFT && candidate.ttft > 0 {
|
||||
if !hasTTFTSample {
|
||||
minTTFT, maxTTFT = candidate.ttft, candidate.ttft
|
||||
hasTTFTSample = true
|
||||
} else {
|
||||
if candidate.ttft < minTTFT {
|
||||
minTTFT = candidate.ttft
|
||||
}
|
||||
if candidate.ttft > maxTTFT {
|
||||
maxTTFT = candidate.ttft
|
||||
}
|
||||
}
|
||||
}
|
||||
loadRate := float64(candidate.loadInfo.LoadRate)
|
||||
loadRateSum += loadRate
|
||||
loadRateSumSquares += loadRate * loadRate
|
||||
}
|
||||
loadSkew = calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates))
|
||||
|
||||
weights := s.service.openAIWSSchedulerWeights()
|
||||
for i := range candidates {
|
||||
item := &candidates[i]
|
||||
priorityFactor := 1.0
|
||||
if maxPriority > minPriority {
|
||||
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
|
||||
}
|
||||
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
|
||||
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
|
||||
errorFactor := 1 - clamp01(item.errorRate)
|
||||
ttftFactor := 0.5
|
||||
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
|
||||
ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT))
|
||||
}
|
||||
quotaFactor := item.account.GetQuotaRemainingFraction()
|
||||
|
||||
item.score = weights.Priority*priorityFactor +
|
||||
weights.Load*loadFactor +
|
||||
weights.Queue*queueFactor +
|
||||
weights.ErrorRate*errorFactor +
|
||||
weights.TTFT*ttftFactor +
|
||||
weights.Quota*quotaFactor
|
||||
}
|
||||
|
||||
if s.service.openAIWSP2CEnabled() {
|
||||
return s.selectByPowerOfTwo(ctx, req, candidates, loadSkew)
|
||||
}
|
||||
}
|
||||
|
||||
topK := 0
|
||||
if len(candidates) > 0 {
|
||||
topK = s.service.openAIWSLBTopK()
|
||||
if topK > len(candidates) {
|
||||
topK = len(candidates)
|
||||
}
|
||||
if topK <= 0 {
|
||||
topK = 1
|
||||
}
|
||||
}
|
||||
|
||||
buildSelectionOrder := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 || topK <= 0 {
|
||||
return nil
|
||||
}
|
||||
groupTopK := topK
|
||||
if groupTopK > len(pool) {
|
||||
groupTopK = len(pool)
|
||||
}
|
||||
ranked := selectTopKOpenAICandidates(pool, groupTopK)
|
||||
return buildOpenAIWeightedSelectionOrder(ranked, req)
|
||||
}
|
||||
sortCompactRetryCandidates := func(pool []openAIAccountCandidateScore) []openAIAccountCandidateScore {
|
||||
if len(pool) == 0 {
|
||||
return nil
|
||||
}
|
||||
ordered := append([]openAIAccountCandidateScore(nil), pool...)
|
||||
sort.SliceStable(ordered, func(i, j int) bool {
|
||||
a, b := ordered[i], ordered[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
if a.loadInfo.WaitingCount != b.loadInfo.WaitingCount {
|
||||
return a.loadInfo.WaitingCount < b.loadInfo.WaitingCount
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
return ordered
|
||||
}
|
||||
|
||||
selectionOrder := make([]openAIAccountCandidateScore, 0, len(allCandidates))
|
||||
if req.RequireCompact {
|
||||
supported := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
unknown := make([]openAIAccountCandidateScore, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
switch openAICompactSupportTier(candidate.account) {
|
||||
case 2:
|
||||
supported = append(supported, candidate)
|
||||
case 1:
|
||||
unknown = append(unknown, candidate)
|
||||
}
|
||||
}
|
||||
if len(supported) == 0 && len(unknown) == 0 && s.service.schedulerSnapshot == nil {
|
||||
return nil, candidateCount, topK, loadSkew, ErrNoAvailableCompactAccounts
|
||||
}
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(supported)...)
|
||||
selectionOrder = append(selectionOrder, buildSelectionOrder(unknown)...)
|
||||
if len(staleSnapshotCompactRetry) > 0 && s.service.schedulerSnapshot != nil {
|
||||
selectionOrder = append(selectionOrder, sortCompactRetryCandidates(staleSnapshotCompactRetry)...)
|
||||
}
|
||||
} else {
|
||||
selectionOrder = buildSelectionOrder(candidates)
|
||||
if s.service.openAIWSP2CEnabled() && len(plan.candidates) > 0 {
|
||||
return s.selectByPowerOfTwo(ctx, req, plan.candidates, loadSkew)
|
||||
}
|
||||
if len(selectionOrder) == 0 {
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(allCandidates) > 0)
|
||||
return nil, candidateCount, topK, loadSkew, noAvailableOpenAISelectionError(req.RequestedModel, req.RequireCompact && len(plan.allCandidates) > 0)
|
||||
}
|
||||
|
||||
compactBlocked := false
|
||||
for i := 0; i < len(selectionOrder); i++ {
|
||||
candidate := selectionOrder[i]
|
||||
fresh := s.service.resolveFreshSchedulableOpenAIAccount(ctx, candidate.account, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
fresh = s.service.recheckSelectedOpenAIAccountFromDB(ctx, fresh, req.RequestedModel, false)
|
||||
if fresh == nil || !s.isAccountTransportCompatible(fresh, req.RequiredTransport) || !s.isAccountRequestCompatible(ctx, fresh, req) {
|
||||
continue
|
||||
}
|
||||
if req.RequireCompact && openAICompactSupportTier(fresh) == 0 {
|
||||
compactBlocked = true
|
||||
continue
|
||||
}
|
||||
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if acquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil && result.Acquired {
|
||||
if req.SessionHash != "" {
|
||||
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, fresh.ID)
|
||||
result, compactBlocked, acquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, selectionOrder)
|
||||
if acquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, acquireErr
|
||||
}
|
||||
if result != nil {
|
||||
return result, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
|
||||
if s.service.concurrencyService != nil {
|
||||
if freshLoadMap, loadErr := s.service.concurrencyService.GetAccountsLoadBatchFresh(ctx, loadReq); loadErr == nil {
|
||||
freshPlan := s.buildOpenAIAccountLoadPlan(req, filtered, freshLoadMap)
|
||||
if len(freshPlan.selectionOrder) > 0 {
|
||||
freshResult, freshCompactBlocked, freshAcquireErr := s.tryAcquireOpenAISelectionOrder(ctx, req, freshPlan.selectionOrder)
|
||||
if freshAcquireErr != nil {
|
||||
return nil, candidateCount, topK, loadSkew, freshAcquireErr
|
||||
}
|
||||
if freshResult != nil {
|
||||
return freshResult, freshPlan.candidateCount, freshPlan.topK, freshPlan.loadSkew, nil
|
||||
}
|
||||
compactBlocked = compactBlocked || freshCompactBlocked
|
||||
selectionOrder = freshPlan.selectionOrder
|
||||
candidateCount = freshPlan.candidateCount
|
||||
topK = freshPlan.topK
|
||||
loadSkew = freshPlan.loadSkew
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: fresh,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, candidateCount, topK, loadSkew, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -899,6 +974,9 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(ctx context.C
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if s != nil && s.service != nil && s.service.isOpenAIAccountRuntimeBlocked(account) {
|
||||
return false
|
||||
}
|
||||
if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -276,9 +276,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -206,9 +206,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -337,9 +337,7 @@ func (s *OpenAIGatewayService) ForwardAsAnthropic(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -187,9 +187,7 @@ func (s *OpenAIGatewayService) forwardResponsesViaRawChatCompletions(
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
}
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
return nil, &UpstreamFailoverError{
|
||||
StatusCode: resp.StatusCode,
|
||||
ResponseBody: respBody,
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai_compat"
|
||||
@ -354,6 +355,9 @@ type OpenAIGatewayService struct {
|
||||
openaiAccountStats *openAIAccountRuntimeStats
|
||||
|
||||
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiAccountRuntimeBlockUntil sync.Map // key: int64(accountID), value: time.Time
|
||||
openaiOAuth429WindowStartUnixNano atomic.Int64
|
||||
openaiOAuth429WindowCount atomic.Int64
|
||||
openaiWSRetryMetrics openAIWSRetryMetrics
|
||||
responseHeaderFilter *responseheaders.CompiledHeaderFilter
|
||||
codexSnapshotThrottle *accountWriteThrottle
|
||||
@ -417,6 +421,12 @@ func NewOpenAIGatewayService(
|
||||
responseHeaderFilter: compileResponseHeaderFilter(cfg),
|
||||
codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval),
|
||||
}
|
||||
if rateLimitService != nil {
|
||||
rateLimitService.SetAccountRuntimeBlocker(svc)
|
||||
}
|
||||
if openAITokenProvider != nil {
|
||||
openAITokenProvider.SetAccountRuntimeBlocker(svc)
|
||||
}
|
||||
svc.logOpenAIWSModeBootstrap()
|
||||
return svc
|
||||
}
|
||||
@ -962,7 +972,7 @@ func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context,
|
||||
zap.String("request_path", strings.TrimSpace(req.URL.Path)),
|
||||
zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)),
|
||||
zap.String("request_host", strings.TrimSpace(req.Host)),
|
||||
zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())),
|
||||
zap.String("request_client_ip", strings.TrimSpace(ip.GetClientIP(c))),
|
||||
zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)),
|
||||
zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))),
|
||||
zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))),
|
||||
@ -1381,13 +1391,18 @@ func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.C
|
||||
return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
|
||||
}
|
||||
|
||||
hydrated, err := s.hydrateSelectedAccount(ctx, selected)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 设置粘性会话绑定
|
||||
// Set sticky session binding
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
|
||||
}
|
||||
|
||||
return s.hydrateSelectedAccount(ctx, selected)
|
||||
return hydrated, nil
|
||||
}
|
||||
|
||||
// tryStickySessionHit 尝试从粘性会话获取账号。
|
||||
@ -1430,6 +1445,10 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
|
||||
if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
return nil
|
||||
}
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
@ -1575,8 +1594,8 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
return nil, err
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
return s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
@ -1627,13 +1646,19 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
|
||||
if account == nil {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else if s.isOpenAIAccountRuntimeBlocked(account) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
|
||||
_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
|
||||
} else {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, account, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, selectErr
|
||||
}
|
||||
_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
|
||||
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||
return selection, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
@ -1665,6 +1690,9 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(acc) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
|
||||
continue
|
||||
}
|
||||
@ -1687,6 +1715,92 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
})
|
||||
}
|
||||
|
||||
tryAcquireFromLoadMap := func(loadMap map[int64]*AccountLoadInfo) (*AccountSelectionResult, bool, error) {
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
selectionOrder := make([]accountWithLoad, 0, len(available))
|
||||
if requireCompact {
|
||||
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
||||
for _, item := range available {
|
||||
if openAICompactSupportTier(item.account) == tier {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
selectionOrder = appendTier(selectionOrder, 2)
|
||||
selectionOrder = appendTier(selectionOrder, 1)
|
||||
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
||||
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
||||
selectionOrder = appendTier(selectionOrder, 0)
|
||||
} else {
|
||||
selectionOrder = append(selectionOrder, available...)
|
||||
}
|
||||
|
||||
for _, item := range selectionOrder {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, true, selectErr
|
||||
}
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return selection, true, nil
|
||||
}
|
||||
}
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
@ -1707,87 +1821,28 @@ func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Contex
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if err == nil && result != nil && result.Acquired {
|
||||
selection, selectErr := s.newAcquiredSelectionResult(ctx, fresh, result.ReleaseFunc)
|
||||
if selectErr != nil {
|
||||
return nil, selectErr
|
||||
}
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var available []accountWithLoad
|
||||
for _, acc := range candidates {
|
||||
loadInfo := loadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
available = append(available, accountWithLoad{
|
||||
account: acc,
|
||||
loadInfo: loadInfo,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(available) > 0 {
|
||||
sort.SliceStable(available, func(i, j int) bool {
|
||||
a, b := available[i], available[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
shuffleWithinSortGroups(available)
|
||||
|
||||
selectionOrder := make([]accountWithLoad, 0, len(available))
|
||||
if requireCompact {
|
||||
appendTier := func(out []accountWithLoad, tier int) []accountWithLoad {
|
||||
for _, item := range available {
|
||||
if openAICompactSupportTier(item.account) == tier {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
selectionOrder = appendTier(selectionOrder, 2)
|
||||
selectionOrder = appendTier(selectionOrder, 1)
|
||||
// tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际
|
||||
// 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。
|
||||
selectionOrder = appendTier(selectionOrder, 0)
|
||||
} else {
|
||||
selectionOrder = append(selectionOrder, available...)
|
||||
}
|
||||
|
||||
for _, item := range selectionOrder {
|
||||
fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
|
||||
if fresh == nil {
|
||||
continue
|
||||
}
|
||||
if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
|
||||
continue
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
|
||||
}
|
||||
return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
|
||||
if selection, attempted, selectErr := tryAcquireFromLoadMap(loadMap); selectErr != nil {
|
||||
return nil, selectErr
|
||||
} else if selection != nil {
|
||||
return selection, nil
|
||||
} else if attempted {
|
||||
if freshLoadMap, loadErr := s.concurrencyService.GetAccountsLoadBatchFresh(ctx, accountLoads); loadErr == nil {
|
||||
if selection, _, selectErr := tryAcquireFromLoadMap(freshLoadMap); selectErr != nil {
|
||||
return nil, selectErr
|
||||
} else if selection != nil {
|
||||
return selection, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1868,6 +1923,9 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.
|
||||
if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(fresh) {
|
||||
return nil
|
||||
}
|
||||
return fresh
|
||||
}
|
||||
|
||||
@ -1889,6 +1947,9 @@ func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Co
|
||||
if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) {
|
||||
return nil
|
||||
}
|
||||
if s.isOpenAIAccountRuntimeBlocked(latest) {
|
||||
return nil
|
||||
}
|
||||
return latest
|
||||
}
|
||||
|
||||
@ -1935,6 +1996,14 @@ func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) newAcquiredSelectionResult(ctx context.Context, account *Account, release func()) (*AccountSelectionResult, error) {
|
||||
selection, err := s.newSelectionResult(ctx, account, true, release, nil)
|
||||
if err != nil && release != nil {
|
||||
release()
|
||||
}
|
||||
return selection, err
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
if s.cfg != nil {
|
||||
return s.cfg.Gateway.Scheduling
|
||||
@ -1996,7 +2065,7 @@ func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode i
|
||||
|
||||
func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
|
||||
// Forward forwards request to OpenAI API
|
||||
@ -3278,9 +3347,7 @@ func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
if s.rateLimitService != nil {
|
||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -3321,12 +3388,9 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
||||
}
|
||||
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||
if s.rateLimitService != nil {
|
||||
// Passthrough mode preserves the raw upstream error response, but runtime
|
||||
// account state still needs to be updated so sticky routing can stop
|
||||
// reusing a freshly rate-limited account.
|
||||
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
// 透传模式保留原始上游错误响应,但运行态账号状态仍需更新,
|
||||
// 避免粘性路由继续复用刚被限流的账号。
|
||||
_ = s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
AccountID: account.ID,
|
||||
@ -4075,10 +4139,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
@ -4210,12 +4271,9 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
|
||||
}
|
||||
|
||||
// Track rate limits and decide whether to trigger secondary failover.
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
)
|
||||
}
|
||||
shouldDisable := s.handleOpenAIAccountUpstreamError(
|
||||
c.Request.Context(), account, resp.StatusCode, resp.Header, body,
|
||||
)
|
||||
kind := "http_error"
|
||||
if shouldDisable {
|
||||
kind = "failover"
|
||||
|
||||
@ -131,8 +131,10 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil))
|
||||
c.Request.RemoteAddr = "172.18.0.1:54321"
|
||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Request.Header.Set("X-Real-IP", "203.0.113.42")
|
||||
c.Request.Header.Set("OpenAI-Beta", "assistants=v2")
|
||||
|
||||
body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`)
|
||||
@ -146,6 +148,8 @@ func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) {
|
||||
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_query", "trace=1"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_client_ip", "203.0.113.42"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_remote_addr", "172.18.0.1:54321"))
|
||||
require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123")))
|
||||
require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta"))
|
||||
require.True(t, logSink.ContainsField("request_body_size"))
|
||||
|
||||
@ -29,6 +29,69 @@ type openAIResponsesImageResult struct {
|
||||
Model string
|
||||
}
|
||||
|
||||
type OpenAIImagesUpstreamError struct {
|
||||
StatusCode int
|
||||
ErrorType string
|
||||
Code string
|
||||
Message string
|
||||
Param string
|
||||
UpstreamRequestID string
|
||||
}
|
||||
|
||||
func (e *OpenAIImagesUpstreamError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
code := strings.TrimSpace(e.Code)
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(e.ErrorType)
|
||||
}
|
||||
message := strings.TrimSpace(e.Message)
|
||||
if code != "" && message != "" {
|
||||
return fmt.Sprintf("openai images upstream error: %s: %s", code, message)
|
||||
}
|
||||
if message != "" {
|
||||
return "openai images upstream error: " + message
|
||||
}
|
||||
if code != "" {
|
||||
return "openai images upstream error: " + code
|
||||
}
|
||||
return "openai images upstream error"
|
||||
}
|
||||
|
||||
func (e *OpenAIImagesUpstreamError) clientStatusCode() int {
|
||||
if e == nil {
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
if e.StatusCode > 0 {
|
||||
return e.StatusCode
|
||||
}
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
func (e *OpenAIImagesUpstreamError) clientErrorType() string {
|
||||
if e == nil {
|
||||
return "upstream_error"
|
||||
}
|
||||
if trimmed := strings.TrimSpace(e.ErrorType); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return "upstream_error"
|
||||
}
|
||||
|
||||
func (e *OpenAIImagesUpstreamError) clientMessage() string {
|
||||
if e == nil {
|
||||
return "Upstream request failed"
|
||||
}
|
||||
if trimmed := strings.TrimSpace(e.Message); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
if trimmed := strings.TrimSpace(e.Code); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
return "Upstream request failed"
|
||||
}
|
||||
|
||||
func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string {
|
||||
if strings.TrimSpace(result.Result) != "" {
|
||||
return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result)
|
||||
@ -465,6 +528,57 @@ func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageRe
|
||||
return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil
|
||||
}
|
||||
|
||||
func extractOpenAIImagesUpstreamError(body []byte) *OpenAIImagesUpstreamError {
|
||||
var upstreamErr *OpenAIImagesUpstreamError
|
||||
forEachOpenAISSEDataPayload(string(body), func(payload []byte) {
|
||||
if upstreamErr != nil || !gjson.ValidBytes(payload) {
|
||||
return
|
||||
}
|
||||
upstreamErr = openAIImagesUpstreamErrorFromSSEPayload(payload)
|
||||
})
|
||||
return upstreamErr
|
||||
}
|
||||
|
||||
func openAIImagesUpstreamErrorFromSSEPayload(payload []byte) *OpenAIImagesUpstreamError {
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return nil
|
||||
}
|
||||
switch gjson.GetBytes(payload, "type").String() {
|
||||
case "error":
|
||||
return openAIImagesUpstreamErrorFromGJSON(gjson.GetBytes(payload, "error"), "")
|
||||
case "response.failed":
|
||||
response := gjson.GetBytes(payload, "response")
|
||||
return openAIImagesUpstreamErrorFromGJSON(response.Get("error"), response.Get("id").String())
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func openAIImagesUpstreamErrorFromGJSON(errorObj gjson.Result, upstreamRequestID string) *OpenAIImagesUpstreamError {
|
||||
if !errorObj.Exists() {
|
||||
return nil
|
||||
}
|
||||
code := strings.TrimSpace(errorObj.Get("code").String())
|
||||
errType := strings.TrimSpace(errorObj.Get("type").String())
|
||||
message := strings.TrimSpace(errorObj.Get("message").String())
|
||||
param := strings.TrimSpace(errorObj.Get("param").String())
|
||||
statusCode := http.StatusBadGateway
|
||||
if strings.EqualFold(code, "moderation_blocked") || strings.EqualFold(errType, "image_generation_user_error") {
|
||||
statusCode = http.StatusBadRequest
|
||||
}
|
||||
if message == "" {
|
||||
message = "Upstream request failed"
|
||||
}
|
||||
return &OpenAIImagesUpstreamError{
|
||||
StatusCode: statusCode,
|
||||
ErrorType: errType,
|
||||
Code: code,
|
||||
Message: sanitizeUpstreamErrorMessage(message),
|
||||
Param: param,
|
||||
UpstreamRequestID: strings.TrimSpace(upstreamRequestID),
|
||||
}
|
||||
}
|
||||
|
||||
func buildOpenAIImagesAPIResponse(
|
||||
results []openAIResponsesImageResult,
|
||||
createdAt int64,
|
||||
@ -531,6 +645,41 @@ func buildOpenAIImagesStreamErrorBody(message string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
func buildOpenAIImagesStreamErrorBodyFromUpstream(err *OpenAIImagesUpstreamError) []byte {
|
||||
if err == nil {
|
||||
return buildOpenAIImagesStreamErrorBody("")
|
||||
}
|
||||
body := buildOpenAIImagesStreamErrorBody(err.clientMessage())
|
||||
body, _ = sjson.SetBytes(body, "error.type", err.clientErrorType())
|
||||
if code := strings.TrimSpace(err.Code); code != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.code", code)
|
||||
}
|
||||
if param := strings.TrimSpace(err.Param); param != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.param", param)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func writeOpenAIImagesUpstreamErrorResponse(c *gin.Context, err *OpenAIImagesUpstreamError) bool {
|
||||
if c == nil || c.Writer == nil || c.Writer.Written() || err == nil {
|
||||
return false
|
||||
}
|
||||
errorObj := gin.H{
|
||||
"type": err.clientErrorType(),
|
||||
"message": err.clientMessage(),
|
||||
}
|
||||
if code := strings.TrimSpace(err.Code); code != "" {
|
||||
errorObj["code"] = code
|
||||
}
|
||||
if param := strings.TrimSpace(err.Param); param != "" {
|
||||
errorObj["param"] = param
|
||||
}
|
||||
c.JSON(err.clientStatusCode(), gin.H{
|
||||
"error": errorObj,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error {
|
||||
if strings.TrimSpace(eventName) != "" {
|
||||
if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil {
|
||||
@ -588,6 +737,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
|
||||
return OpenAIUsage{}, 0, nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
if upstreamErr := extractOpenAIImagesUpstreamError(body); upstreamErr != nil {
|
||||
setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "")
|
||||
writeOpenAIImagesUpstreamErrorResponse(c, upstreamErr)
|
||||
return OpenAIUsage{}, 0, nil, upstreamErr
|
||||
}
|
||||
return OpenAIUsage{}, 0, nil, fmt.Errorf("upstream did not return image output")
|
||||
}
|
||||
if strings.TrimSpace(firstMeta.Model) == "" {
|
||||
@ -742,6 +896,16 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
|
||||
imageCount = len(emitted)
|
||||
imageOutputSizes = openAIResponsesImageResultSizes(finalResults)
|
||||
processDataDone = true
|
||||
case "error", "response.failed":
|
||||
if upstreamErr := openAIImagesUpstreamErrorFromSSEPayload(dataBytes); upstreamErr != nil {
|
||||
if !clientDisconnected {
|
||||
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBodyFromUpstream(upstreamErr))
|
||||
}
|
||||
setOpsUpstreamError(c, upstreamErr.clientStatusCode(), upstreamErr.clientMessage(), "")
|
||||
processDataErr = upstreamErr
|
||||
processDataDone = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -553,6 +553,59 @@ func TestOpenAIGatewayServiceForwardImages_OAuthPassesNAndReturnsAllImages(t *te
|
||||
require.Equal(t, "draw a cat 3", gjson.Get(rec.Body.String(), "data.2.revised_prompt").String())
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_OAuthNonStreamModerationBlockedReturnsClientError(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw blocked image","response_format":"b64_json"}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = req
|
||||
c.Set("api_key", &APIKey{ID: 42})
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
parsed, err := svc.ParseOpenAIImagesRequest(c, body)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc.httpUpstream = &httpUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"text/event-stream"},
|
||||
"X-Request-Id": []string{"req_img_blocked"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
"data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000020}}\n\n" +
|
||||
"data: {\"type\":\"error\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}\n\n" +
|
||||
"data: {\"type\":\"response.failed\",\"response\":{\"id\":\"resp_blocked\",\"status\":\"failed\",\"error\":{\"type\":\"image_generation_user_error\",\"code\":\"moderation_blocked\",\"message\":\"Your request was rejected by the safety system. safety_violations=[sexual].\"}}}\n\n",
|
||||
)),
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "openai-oauth",
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "token-123",
|
||||
},
|
||||
}
|
||||
|
||||
result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "")
|
||||
require.Nil(t, result)
|
||||
var upstreamErr *OpenAIImagesUpstreamError
|
||||
require.ErrorAs(t, err, &upstreamErr)
|
||||
require.Equal(t, http.StatusBadRequest, upstreamErr.StatusCode)
|
||||
require.Equal(t, "moderation_blocked", upstreamErr.Code)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Equal(t, "image_generation_user_error", gjson.Get(rec.Body.String(), "error.type").String())
|
||||
require.Equal(t, "moderation_blocked", gjson.Get(rec.Body.String(), "error.code").String())
|
||||
require.Contains(t, gjson.Get(rec.Body.String(), "error.message").String(), "safety system")
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayServiceForwardImages_APIKeyGenerationUsesConfiguredV1BaseURL(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","response_format":"b64_json"}`)
|
||||
|
||||
@ -80,6 +80,7 @@ type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
metrics *openAITokenRuntimeMetricsStore
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
@ -111,6 +112,10 @@ func (p *OpenAITokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
p.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||
if p == nil {
|
||||
return OpenAITokenRuntimeMetrics{}
|
||||
@ -275,6 +280,9 @@ func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account
|
||||
if p == nil || p.accountRepo == nil || account == nil {
|
||||
return
|
||||
}
|
||||
if p.runtimeBlocker != nil {
|
||||
p.runtimeBlocker.BlockAccountScheduling(account, time.Time{}, "missing_refresh_token")
|
||||
}
|
||||
bgCtx := context.Background()
|
||||
if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
|
||||
slog.Warn("openai_token_provider.set_error_failed",
|
||||
|
||||
@ -952,6 +952,8 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
||||
cache.getErr = errors.New("simulated cache miss")
|
||||
|
||||
provider := NewOpenAITokenProvider(repo, cache, nil)
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
provider.SetAccountRuntimeBlocker(blocker)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
@ -960,4 +962,7 @@ func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T)
|
||||
|
||||
require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
|
||||
require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
|
||||
require.Len(t, blocker.accounts, 1)
|
||||
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||
require.Equal(t, "missing_refresh_token", blocker.reasons[0])
|
||||
}
|
||||
|
||||
@ -4091,7 +4091,7 @@ func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Contex
|
||||
if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
|
||||
return
|
||||
}
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
s.handleOpenAIAccountUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
|
||||
}
|
||||
|
||||
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
|
||||
|
||||
@ -115,6 +115,10 @@ func (r *paymentOrderLifecycleRedeemRepo) Update(context.Context, *RedeemCode) e
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
func (r *paymentOrderLifecycleRedeemRepo) BatchUpdate(context.Context, []int64, RedeemCodeBatchUpdateFields) (int64, error) {
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
func (r *paymentOrderLifecycleRedeemRepo) Delete(context.Context, int64) error {
|
||||
panic("unexpected call")
|
||||
}
|
||||
|
||||
@ -28,10 +28,16 @@ type RateLimitService struct {
|
||||
openAI403CounterCache OpenAI403CounterCache
|
||||
settingService *SettingService
|
||||
tokenCacheInvalidator TokenCacheInvalidator
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
|
||||
type AccountRuntimeBlocker interface {
|
||||
BlockAccountScheduling(account *Account, until time.Time, reason string)
|
||||
ClearAccountSchedulingBlock(accountID int64)
|
||||
}
|
||||
|
||||
// SuccessfulTestRecoveryResult 表示测试成功后恢复了哪些运行时状态。
|
||||
type SuccessfulTestRecoveryResult struct {
|
||||
ClearedError bool
|
||||
@ -98,6 +104,24 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali
|
||||
s.tokenCacheInvalidator = invalidator
|
||||
}
|
||||
|
||||
func (s *RateLimitService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
s.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (s *RateLimitService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||
}
|
||||
|
||||
// ErrorPolicyResult 表示错误策略检查的结果
|
||||
type ErrorPolicyResult int
|
||||
|
||||
@ -240,6 +264,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
cooldownMinutes = 10
|
||||
}
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "oauth_401")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, msg); err != nil {
|
||||
slog.Warn("oauth_401_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@ -678,6 +703,7 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "auth_error")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -758,6 +784,7 @@ func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account
|
||||
|
||||
until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
|
||||
reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "openai_403_temp")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err)
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
@ -823,6 +850,7 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
|
||||
// handleCustomErrorCode 处理自定义错误码,停止账号调度
|
||||
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
|
||||
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "custom_error_code")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
|
||||
slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
|
||||
return
|
||||
@ -838,6 +866,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
persistOpenAI429PlanType(ctx, s.accountRepo, account, responseBody)
|
||||
s.persistOpenAICodexSnapshot(ctx, account, headers)
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
s.notifyAccountSchedulingBlocked(account, *resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -849,6 +878,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
|
||||
// 2. Anthropic 平台:尝试解析 per-window 头(5h / 7d),选择实际触发的窗口
|
||||
if result := calculateAnthropic429ResetTime(headers); result != nil {
|
||||
s.notifyAccountSchedulingBlocked(account, result.resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, result.resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -878,6 +908,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -889,6 +920,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
// 尝试解析 Gemini 格式(用于其他平台)
|
||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
s.notifyAccountSchedulingBlocked(account, resetTime, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -924,6 +956,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
s.notifyAccountSchedulingBlocked(account, resetAt, "429")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -948,6 +981,7 @@ func (s *RateLimitService) apply429FallbackRateLimit(ctx context.Context, accoun
|
||||
|
||||
resetAt := time.Now().Add(cooldown)
|
||||
slog.Warn("rate_limit_429_fallback_used", "account_id", account.ID, "platform", account.Platform, "reason", reason, "using_default", cooldown.String())
|
||||
s.notifyAccountSchedulingBlocked(account, resetAt, "429_fallback")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@ -1291,6 +1325,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "529")
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
@ -1420,6 +1455,7 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
|
||||
}
|
||||
}
|
||||
s.ResetOpenAI403Counter(ctx, accountID)
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1460,6 +1496,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
|
||||
}
|
||||
if result.ClearedError || result.ClearedRateLimit {
|
||||
s.ResetOpenAI403Counter(ctx, accountID)
|
||||
if result.ClearedError && !result.ClearedRateLimit {
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@ -1484,6 +1523,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
|
||||
if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil {
|
||||
slog.Warn("clear_model_rate_limits_on_temp_unsched_reset_failed", "account_id", accountID, "error", err)
|
||||
}
|
||||
s.notifyAccountSchedulingBlockCleared(accountID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1694,6 +1734,7 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
|
||||
reason = strings.TrimSpace(state.ErrorMessage)
|
||||
}
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, until, "temp_unschedulable")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
@ -1798,6 +1839,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
reason = state.ErrorMessage
|
||||
}
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, until, "stream_timeout_temp_unschedulable")
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
@ -1824,6 +1866,7 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
|
||||
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
|
||||
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
|
||||
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "stream_timeout_error")
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return false
|
||||
|
||||
@ -6,16 +6,36 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type runtimeBlockRecorder struct {
|
||||
accounts []*Account
|
||||
until []time.Time
|
||||
reasons []string
|
||||
clearedIDs []int64
|
||||
}
|
||||
|
||||
func (r *runtimeBlockRecorder) BlockAccountScheduling(account *Account, until time.Time, reason string) {
|
||||
r.accounts = append(r.accounts, account)
|
||||
r.until = append(r.until, until)
|
||||
r.reasons = append(r.reasons, reason)
|
||||
}
|
||||
|
||||
func (r *runtimeBlockRecorder) ClearAccountSchedulingBlock(accountID int64) {
|
||||
r.clearedIDs = append(r.clearedIDs, accountID)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
counter := &openAI403CounterCacheStub{counts: []int64{1}}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetOpenAI403CounterCache(counter)
|
||||
service.SetAccountRuntimeBlocker(blocker)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformOpenAI,
|
||||
@ -35,6 +55,10 @@ func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Contains(t, repo.lastTempReason, "temporary edge rejection")
|
||||
require.Contains(t, repo.lastTempReason, "(1/3)")
|
||||
require.Len(t, blocker.accounts, 1)
|
||||
require.Equal(t, account.ID, blocker.accounts[0].ID)
|
||||
require.Equal(t, "openai_403_temp", blocker.reasons[0])
|
||||
require.True(t, blocker.until[0].After(time.Now()))
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
|
||||
|
||||
@ -219,7 +219,9 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
||||
},
|
||||
}
|
||||
cache := &tempUnschedCacheRecorder{}
|
||||
blocker := &runtimeBlockRecorder{}
|
||||
svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache)
|
||||
svc.SetAccountRuntimeBlocker(blocker)
|
||||
|
||||
result, err := svc.RecoverAccountAfterSuccessfulTest(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
@ -234,6 +236,7 @@ func TestRateLimitService_RecoverAccountAfterSuccessfulTest_ClearsErrorAndRateLi
|
||||
require.Equal(t, 1, repo.clearModelRateLimitCalls)
|
||||
require.Equal(t, 1, repo.clearTempUnschedCalls)
|
||||
require.Equal(t, []int64{42}, cache.deletedIDs)
|
||||
require.Equal(t, []int64{42}, blocker.clearedIDs)
|
||||
}
|
||||
|
||||
func TestRateLimitService_RecoverAccountAfterSuccessfulTest_NoRecoverableStateIsNoop(t *testing.T) {
|
||||
|
||||
@ -54,6 +54,7 @@ type RedeemCodeRepository interface {
|
||||
GetByID(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*RedeemCode, error)
|
||||
Update(ctx context.Context, code *RedeemCode) error
|
||||
BatchUpdate(ctx context.Context, ids []int64, fields RedeemCodeBatchUpdateFields) (int64, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Use(ctx context.Context, id, userID int64) error
|
||||
|
||||
@ -82,6 +83,54 @@ type RedeemCodeResponse struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type NullableTimeUpdate struct {
|
||||
Set bool
|
||||
Value *time.Time
|
||||
}
|
||||
|
||||
type NullableInt64Update struct {
|
||||
Set bool
|
||||
Value *int64
|
||||
}
|
||||
|
||||
type RedeemCodeBatchUpdateFields struct {
|
||||
Status *string
|
||||
ExpiresAt NullableTimeUpdate
|
||||
Notes *string
|
||||
GroupID NullableInt64Update
|
||||
|
||||
// Core fields are intentionally modeled only so service validation can
|
||||
// reject payloads that try to mutate redemption value semantics in bulk.
|
||||
Type *string
|
||||
Value *float64
|
||||
}
|
||||
|
||||
func (f RedeemCodeBatchUpdateFields) HasChanges() bool {
|
||||
return f.Status != nil ||
|
||||
f.ExpiresAt.Set ||
|
||||
f.Notes != nil ||
|
||||
f.GroupID.Set ||
|
||||
f.Type != nil ||
|
||||
f.Value != nil
|
||||
}
|
||||
|
||||
func (f RedeemCodeBatchUpdateFields) HasCoreFieldChanges() bool {
|
||||
return f.Type != nil || f.Value != nil
|
||||
}
|
||||
|
||||
func (f RedeemCodeBatchUpdateFields) TouchesUsedSensitiveFields() bool {
|
||||
return f.Status != nil || f.ExpiresAt.Set || f.GroupID.Set
|
||||
}
|
||||
|
||||
type RedeemCodeBatchUpdateInput struct {
|
||||
IDs []int64
|
||||
Fields RedeemCodeBatchUpdateFields
|
||||
}
|
||||
|
||||
type RedeemCodeBatchUpdateResult struct {
|
||||
Updated int64 `json:"updated"`
|
||||
}
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo RedeemCodeRepository
|
||||
@ -218,6 +267,61 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RedeemService) BatchUpdate(ctx context.Context, input *RedeemCodeBatchUpdateInput) (*RedeemCodeBatchUpdateResult, error) {
|
||||
if input == nil {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID", "batch update input is required")
|
||||
}
|
||||
if len(input.IDs) == 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required")
|
||||
}
|
||||
if !input.Fields.HasChanges() {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_EMPTY", "at least one field must be selected")
|
||||
}
|
||||
if input.Fields.HasCoreFieldChanges() {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_CORE_FIELDS_IMMUTABLE", "type and value cannot be batch updated")
|
||||
}
|
||||
|
||||
ids := make([]int64, 0, len(input.IDs))
|
||||
seen := make(map[int64]struct{}, len(input.IDs))
|
||||
for _, id := range input.IDs {
|
||||
if id <= 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_INVALID_ID", "ids must be positive")
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_BATCH_UPDATE_IDS_REQUIRED", "ids are required")
|
||||
}
|
||||
|
||||
if input.Fields.Status != nil {
|
||||
switch *input.Fields.Status {
|
||||
case StatusUnused, StatusDisabled:
|
||||
default:
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_STATUS_INVALID", "status must be unused or disabled")
|
||||
}
|
||||
}
|
||||
if input.Fields.ExpiresAt.Set && input.Fields.ExpiresAt.Value != nil {
|
||||
expiresAt := input.Fields.ExpiresAt.Value.UTC()
|
||||
if !expiresAt.After(time.Now().UTC()) {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
|
||||
}
|
||||
input.Fields.ExpiresAt.Value = &expiresAt
|
||||
}
|
||||
if input.Fields.GroupID.Set && input.Fields.GroupID.Value != nil && *input.Fields.GroupID.Value <= 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_GROUP_ID_INVALID", "group_id must be positive")
|
||||
}
|
||||
|
||||
updated, err := s.redeemRepo.BatchUpdate(ctx, ids, input.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RedeemCodeBatchUpdateResult{Updated: updated}, nil
|
||||
}
|
||||
|
||||
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
|
||||
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
|
||||
75
backend/internal/service/redeem_service_batch_update_test.go
Normal file
75
backend/internal/service/redeem_service_batch_update_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedeemService_BatchUpdate_PartialFields(t *testing.T) {
|
||||
status := StatusDisabled
|
||||
notes := "maintenance window"
|
||||
expiresAt := time.Now().UTC().Add(24 * time.Hour)
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &RedeemService{redeemRepo: repo}
|
||||
|
||||
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
|
||||
IDs: []int64{1, 2, 2},
|
||||
Fields: RedeemCodeBatchUpdateFields{
|
||||
Status: &status,
|
||||
ExpiresAt: NullableTimeUpdate{Set: true, Value: &expiresAt},
|
||||
Notes: ¬es,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), result.Updated)
|
||||
require.True(t, repo.batchUpdateCalled)
|
||||
require.Equal(t, []int64{1, 2}, repo.batchUpdateIDs)
|
||||
require.Equal(t, &status, repo.batchUpdateFields.Status)
|
||||
require.True(t, repo.batchUpdateFields.ExpiresAt.Set)
|
||||
require.WithinDuration(t, expiresAt, *repo.batchUpdateFields.ExpiresAt.Value, time.Second)
|
||||
require.Equal(t, ¬es, repo.batchUpdateFields.Notes)
|
||||
require.False(t, repo.batchUpdateFields.GroupID.Set)
|
||||
require.Nil(t, repo.batchUpdateFields.Type)
|
||||
require.Nil(t, repo.batchUpdateFields.Value)
|
||||
}
|
||||
|
||||
func TestRedeemService_BatchUpdate_RejectsInvalidID(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &RedeemService{redeemRepo: repo}
|
||||
notes := "bad id"
|
||||
|
||||
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
|
||||
IDs: []int64{1, 0},
|
||||
Fields: RedeemCodeBatchUpdateFields{Notes: ¬es},
|
||||
})
|
||||
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.True(t, infraerrors.IsBadRequest(err))
|
||||
require.False(t, repo.batchUpdateCalled)
|
||||
}
|
||||
|
||||
func TestRedeemService_BatchUpdate_RejectsCoreFieldsForUsedCodes(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &RedeemService{redeemRepo: repo}
|
||||
newValue := 100.0
|
||||
|
||||
result, err := svc.BatchUpdate(context.Background(), &RedeemCodeBatchUpdateInput{
|
||||
IDs: []int64{42},
|
||||
Fields: RedeemCodeBatchUpdateFields{
|
||||
Value: &newValue,
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.True(t, infraerrors.IsBadRequest(err))
|
||||
require.False(t, repo.batchUpdateCalled)
|
||||
}
|
||||
@ -26,12 +26,17 @@ func IsRegistrationEmailSuffixAllowed(email string, whitelist []string) bool {
|
||||
if len(whitelist) == 0 {
|
||||
return true
|
||||
}
|
||||
suffix := RegistrationEmailSuffix(email)
|
||||
if suffix == "" {
|
||||
_, domain, ok := splitEmailForPolicy(email)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
suffix := "@" + domain
|
||||
for _, allowed := range whitelist {
|
||||
if suffix == allowed {
|
||||
allowed = strings.ToLower(strings.TrimSpace(allowed))
|
||||
if strings.HasPrefix(allowed, "@") && suffix == allowed {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(allowed, "*.") && registrationEmailDomainMatchesWildcard(domain, allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -98,6 +103,14 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(value, "*.") {
|
||||
domain := strings.TrimPrefix(value, "*.")
|
||||
if !isValidRegistrationEmailDomain(domain) {
|
||||
return "", fmt.Errorf("invalid email suffix: %q", raw)
|
||||
}
|
||||
return "*." + domain, nil
|
||||
}
|
||||
|
||||
domain := value
|
||||
if strings.Contains(value, "@") {
|
||||
if !strings.HasPrefix(value, "@") || strings.Count(value, "@") != 1 {
|
||||
@ -106,13 +119,27 @@ func normalizeRegistrationEmailSuffix(raw string) (string, error) {
|
||||
domain = strings.TrimPrefix(value, "@")
|
||||
}
|
||||
|
||||
if domain == "" || strings.Contains(domain, "@") || !registrationEmailDomainPattern.MatchString(domain) {
|
||||
if !isValidRegistrationEmailDomain(domain) {
|
||||
return "", fmt.Errorf("invalid email suffix: %q", raw)
|
||||
}
|
||||
|
||||
return "@" + domain, nil
|
||||
}
|
||||
|
||||
func isValidRegistrationEmailDomain(domain string) bool {
|
||||
return domain != "" &&
|
||||
!strings.Contains(domain, "@") &&
|
||||
registrationEmailDomainPattern.MatchString(domain)
|
||||
}
|
||||
|
||||
func registrationEmailDomainMatchesWildcard(domain string, allowed string) bool {
|
||||
base := strings.TrimPrefix(allowed, "*.")
|
||||
if !isValidRegistrationEmailDomain(base) {
|
||||
return false
|
||||
}
|
||||
return domain == base || strings.HasSuffix(domain, "."+base)
|
||||
}
|
||||
|
||||
func splitEmailForPolicy(raw string) (local string, domain string, ok bool) {
|
||||
email := strings.ToLower(strings.TrimSpace(raw))
|
||||
local, domain, found := strings.Cut(email, "@")
|
||||
|
||||
@ -9,23 +9,36 @@ import (
|
||||
)
|
||||
|
||||
func TestNormalizeRegistrationEmailSuffixWhitelist(t *testing.T) {
|
||||
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar "})
|
||||
got, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"example.com", "@EXAMPLE.COM", " @foo.bar ", "*.EDU.CN"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got)
|
||||
}
|
||||
|
||||
func TestNormalizeRegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
|
||||
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{"@invalid_domain"})
|
||||
require.Error(t, err)
|
||||
for _, item := range []string{"@invalid_domain", "*.", "*", "*.@", "*.foo"} {
|
||||
t.Run(item, func(t *testing.T) {
|
||||
_, err := NormalizeRegistrationEmailSuffixWhitelist([]string{item})
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRegistrationEmailSuffixWhitelist(t *testing.T) {
|
||||
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","@invalid_domain"]`)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar"}, got)
|
||||
got := ParseRegistrationEmailSuffixWhitelist(`["example.com","@foo.bar","*.EDU.CN","@invalid_domain","*.foo"]`)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, got)
|
||||
}
|
||||
|
||||
func TestIsRegistrationEmailSuffixAllowed(t *testing.T) {
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@example.com", []string{"@example.com"}))
|
||||
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.example.com", []string{"@example.com"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@qq.com", []string{"@qq.com"}))
|
||||
require.False(t, IsRegistrationEmailSuffixAllowed("user@sub.qq.com", []string{"@qq.com"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("student@cs.edu.cn", []string{"*.edu.cn"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("student@edu.cn", []string{"*.edu.cn"}))
|
||||
require.False(t, IsRegistrationEmailSuffixAllowed("student@foo.cn", []string{"*.edu.cn"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@a.com", []string{"@a.com", "*.b.cn"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@school.b.cn", []string{"@a.com", "*.b.cn"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@b.cn", []string{"@a.com", "*.b.cn"}))
|
||||
require.False(t, IsRegistrationEmailSuffixAllowed("user@c.cn", []string{"@a.com", "*.b.cn"}))
|
||||
require.True(t, IsRegistrationEmailSuffixAllowed("user@any.com", []string{}))
|
||||
}
|
||||
|
||||
@ -114,6 +114,31 @@ func TestOpenAISelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedul
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAINewAcquiredSelectionResult_ReleasesSlotWhenHydrationFails(t *testing.T) {
|
||||
cache := &snapshotHydrationCache{
|
||||
accounts: map[int64]*Account{},
|
||||
}
|
||||
schedulerSnapshot := NewSchedulerSnapshotService(cache, nil, stubOpenAIAccountRepo{}, nil, nil)
|
||||
svc := &OpenAIGatewayService{
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
}
|
||||
releaseCalls := 0
|
||||
|
||||
selection, err := svc.newAcquiredSelectionResult(context.Background(), &Account{ID: 1001}, func() {
|
||||
releaseCalls++
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected hydration error")
|
||||
}
|
||||
if selection != nil {
|
||||
t.Fatalf("expected nil selection on hydration error")
|
||||
}
|
||||
if releaseCalls != 1 {
|
||||
t.Fatalf("expected release to be called once, got %d", releaseCalls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewaySelectAccountWithLoadAwareness_HydratesSelectedAccountFromSchedulerSnapshot(t *testing.T) {
|
||||
cache := &snapshotHydrationCache{
|
||||
snapshot: []*Account{
|
||||
|
||||
@ -597,6 +597,23 @@ func (s *SettingService) SetProxyRepository(repo ProxyRepository) {
|
||||
s.proxyRepo = repo
|
||||
}
|
||||
|
||||
func (s *SettingService) LoadAPIKeyACLTrustForwardedIPSetting(ctx context.Context) error {
|
||||
if s == nil || s.cfg == nil || s.settingRepo == nil {
|
||||
return nil
|
||||
}
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyAPIKeyACLTrustForwardedIP)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
s.cfg.SetTrustForwardedIPForAPIKeyACL(s.cfg.Security.TrustForwardedIPForAPIKeyACL)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("get api key acl forwarded ip setting: %w", err)
|
||||
}
|
||||
enabled := value == "true"
|
||||
s.cfg.SetTrustForwardedIPForAPIKeyACL(enabled)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllSettings 获取所有系统设置
|
||||
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
|
||||
settings, err := s.settingRepo.GetAll(ctx)
|
||||
@ -633,6 +650,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyLoginAgreementDocuments,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeyAPIKeyACLTrustForwardedIP,
|
||||
SettingKeySiteName,
|
||||
SettingKeySiteLogo,
|
||||
SettingKeySiteSubtitle,
|
||||
@ -1568,6 +1586,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
if settings.TurnstileSecretKey != "" {
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
updates[SettingKeyAPIKeyACLTrustForwardedIP] = strconv.FormatBool(settings.APIKeyACLTrustForwardedIP)
|
||||
|
||||
// LinuxDo Connect OAuth 登录
|
||||
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||
@ -1777,10 +1796,11 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
|
||||
updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
|
||||
updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)
|
||||
|
||||
// Balance low notification
|
||||
// 余额、订阅到期与账号限额通知
|
||||
updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
|
||||
updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
|
||||
updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
|
||||
updates[SettingKeySubscriptionExpiryNotifyEnabled] = strconv.FormatBool(settings.SubscriptionExpiryNotifyEnabled)
|
||||
updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
|
||||
updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)
|
||||
|
||||
@ -1867,6 +1887,9 @@ func (s *SettingService) refreshCachedSettings(settings *SystemSettings) {
|
||||
enabled: settings.OpenAIAdvancedSchedulerEnabled,
|
||||
expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(),
|
||||
})
|
||||
if s.cfg != nil {
|
||||
s.cfg.SetTrustForwardedIPForAPIKeyACL(settings.APIKeyACLTrustForwardedIP)
|
||||
}
|
||||
if s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
@ -2463,6 +2486,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
|
||||
SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
|
||||
SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
|
||||
SettingKeyAPIKeyACLTrustForwardedIP: "false",
|
||||
SettingKeySiteName: "Sub2API",
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyPurchaseSubscriptionEnabled: "false",
|
||||
@ -2622,6 +2646,12 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
if loginAgreementUpdatedAt == "" {
|
||||
loginAgreementUpdatedAt = defaultLoginAgreementDate
|
||||
}
|
||||
apiKeyACLTrustForwardedIP := false
|
||||
if value, ok := settings[SettingKeyAPIKeyACLTrustForwardedIP]; ok {
|
||||
apiKeyACLTrustForwardedIP = value == "true"
|
||||
} else if s != nil && s.cfg != nil {
|
||||
apiKeyACLTrustForwardedIP = s.cfg.Security.TrustForwardedIPForAPIKeyACL
|
||||
}
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
@ -2644,6 +2674,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
|
||||
APIKeyACLTrustForwardedIP: apiKeyACLTrustForwardedIP,
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
@ -3131,14 +3162,15 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
|
||||
result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"
|
||||
|
||||
// Balance low notification
|
||||
// 余额、订阅到期与账号限额通知
|
||||
result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
|
||||
if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
|
||||
result.BalanceLowNotifyThreshold = v
|
||||
}
|
||||
result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]
|
||||
result.SubscriptionExpiryNotifyEnabled = !isFalseSettingValue(settings[SettingKeySubscriptionExpiryNotifyEnabled])
|
||||
|
||||
// Account quota notification
|
||||
// 账号限额通知
|
||||
result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
|
||||
if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" {
|
||||
result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw)
|
||||
|
||||
@ -53,14 +53,14 @@ func TestSettingService_GetPublicSettings_ExposesRegistrationEmailSuffixWhitelis
|
||||
values: map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","@invalid_domain",""]`,
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@EXAMPLE.com"," @foo.bar ","*.EDU.CN","@invalid_domain",""]`,
|
||||
},
|
||||
}
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
settings, err := svc.GetPublicSettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar"}, settings.RegistrationEmailSuffixWhitelist)
|
||||
require.Equal(t, []string{"@example.com", "@foo.bar", "*.edu.cn"}, settings.RegistrationEmailSuffixWhitelist)
|
||||
}
|
||||
|
||||
func TestSettingService_GetPublicSettings_ExposesTablePreferences(t *testing.T) {
|
||||
|
||||
@ -213,10 +213,10 @@ func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Normaliz
|
||||
svc := NewSettingService(repo, &config.Config{})
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar "},
|
||||
RegistrationEmailSuffixWhitelist: []string{"example.com", "@EXAMPLE.com", " @foo.bar ", "*.EDU.CN"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `["@example.com","@foo.bar"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
|
||||
require.Equal(t, `["@example.com","@foo.bar","*.edu.cn"]`, repo.updates[SettingKeyRegistrationEmailSuffixWhitelist])
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_RegistrationEmailSuffixWhitelist_Invalid(t *testing.T) {
|
||||
@ -290,6 +290,30 @@ func TestSettingService_UpdateSettings_AntigravityUserAgentVersion(t *testing.T)
|
||||
require.Equal(t, "1.23.2", repo.updates[SettingKeyAntigravityUserAgentVersion])
|
||||
}
|
||||
|
||||
func TestSettingService_UpdateSettings_APIKeyACLTrustForwardedIPRefreshesConfig(t *testing.T) {
|
||||
repo := &settingUpdateRepoStub{}
|
||||
cfg := &config.Config{}
|
||||
svc := NewSettingService(repo, cfg)
|
||||
|
||||
err := svc.UpdateSettings(context.Background(), &SystemSettings{
|
||||
APIKeyACLTrustForwardedIP: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "true", repo.updates[SettingKeyAPIKeyACLTrustForwardedIP])
|
||||
require.True(t, cfg.Security.TrustForwardedIPForAPIKeyACL)
|
||||
require.True(t, cfg.TrustForwardedIPForAPIKeyACL())
|
||||
}
|
||||
|
||||
func TestSettingService_ParseSettings_APIKeyACLTrustForwardedIPFallsBackToConfigWhenMissing(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Security.TrustForwardedIPForAPIKeyACL = true
|
||||
svc := NewSettingService(&settingUpdateRepoStub{}, cfg)
|
||||
|
||||
got := svc.parseSettings(map[string]string{})
|
||||
|
||||
require.True(t, got.APIKeyACLTrustForwardedIP)
|
||||
}
|
||||
|
||||
func TestSettingService_GetAntigravityUserAgentVersion_Precedence(t *testing.T) {
|
||||
t.Run("后台设置优先", func(t *testing.T) {
|
||||
svc := NewSettingService(&settingAntigravityUARepoStub{values: map[string]string{
|
||||
|
||||
@ -38,6 +38,7 @@ type SystemSettings struct {
|
||||
TurnstileSiteKey string
|
||||
TurnstileSecretKey string
|
||||
TurnstileSecretKeyConfigured bool
|
||||
APIKeyACLTrustForwardedIP bool
|
||||
|
||||
// LinuxDo Connect OAuth 登录
|
||||
LinuxDoConnectEnabled bool
|
||||
@ -204,15 +205,18 @@ type SystemSettings struct {
|
||||
PaymentVisibleMethodAlipayEnabled bool
|
||||
PaymentVisibleMethodWxpayEnabled bool
|
||||
|
||||
// OpenAI account scheduling
|
||||
// OpenAI 账号调度
|
||||
OpenAIAdvancedSchedulerEnabled bool
|
||||
|
||||
// Balance low notification
|
||||
// 余额不足提醒
|
||||
BalanceLowNotifyEnabled bool
|
||||
BalanceLowNotifyThreshold float64
|
||||
BalanceLowNotifyRechargeURL string
|
||||
|
||||
// Account quota notification
|
||||
// 订阅到期提醒
|
||||
SubscriptionExpiryNotifyEnabled bool
|
||||
|
||||
// 账号限额通知
|
||||
AccountQuotaNotifyEnabled bool
|
||||
AccountQuotaNotifyEmails []NotifyEmailEntry
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
@ -14,6 +15,7 @@ import (
|
||||
// SubscriptionExpiryService periodically updates expired subscription status.
|
||||
type SubscriptionExpiryService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
settingRepo SettingRepository
|
||||
notificationEmailService *NotificationEmailService
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
@ -29,6 +31,10 @@ func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interv
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) SetSettingRepository(settingRepo SettingRepository) {
|
||||
s.settingRepo = settingRepo
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) SetNotificationEmailService(notificationEmailService *NotificationEmailService) {
|
||||
s.notificationEmailService = notificationEmailService
|
||||
}
|
||||
@ -84,6 +90,9 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
|
||||
if s == nil || s.userSubRepo == nil || s.notificationEmailService == nil {
|
||||
return
|
||||
}
|
||||
if !s.expiryReminderEnabled(ctx) {
|
||||
return
|
||||
}
|
||||
for page := 1; ; page++ {
|
||||
subs, pag, err := s.userSubRepo.List(ctx, pagination.PaginationParams{Page: page, PageSize: 200}, nil, nil, SubscriptionStatusActive, "", "expires_at", "asc")
|
||||
if err != nil {
|
||||
@ -99,6 +108,21 @@ func (s *SubscriptionExpiryService) sendExpiryReminders(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) expiryReminderEnabled(ctx context.Context) bool {
|
||||
if s == nil || s.settingRepo == nil {
|
||||
return true
|
||||
}
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySubscriptionExpiryNotifyEnabled)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return true
|
||||
}
|
||||
log.Printf("[SubscriptionExpiry] Read expiry reminder switch failed: %v", err)
|
||||
return false
|
||||
}
|
||||
return !isFalseSettingValue(value)
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) sendExpiryReminderIfDue(ctx context.Context, sub *UserSubscription) {
|
||||
if sub == nil || sub.User == nil || sub.Group == nil || sub.User.Email == "" {
|
||||
return
|
||||
|
||||
164
backend/internal/service/subscription_expiry_service_test.go
Normal file
164
backend/internal/service/subscription_expiry_service_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type subscriptionExpiryRepoStub struct {
|
||||
listCalls int
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) Create(context.Context, *UserSubscription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) GetByID(context.Context, int64) (*UserSubscription, error) {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) GetByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) GetActiveByUserIDAndGroupID(context.Context, int64, int64) (*UserSubscription, error) {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) Update(context.Context, *UserSubscription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) Delete(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ListByUserID(context.Context, int64) ([]UserSubscription, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ListActiveByUserID(context.Context, int64) ([]UserSubscription, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) List(context.Context, pagination.PaginationParams, *int64, *int64, string, string, string, string) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
r.listCalls++
|
||||
return nil, &pagination.PaginationResult{Page: 1, Pages: 1}, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ExistsByUserIDAndGroupID(context.Context, int64, int64) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ExtendExpiry(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) UpdateStatus(context.Context, int64, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) UpdateNotes(context.Context, int64, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ActivateWindows(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ResetDailyUsage(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ResetWeeklyUsage(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) ResetMonthlyUsage(context.Context, int64, time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) IncrementUsage(context.Context, int64, float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpiryRepoStub) BatchUpdateExpiredStatus(context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
type subscriptionExpirySettingRepoStub struct {
|
||||
values map[string]string
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) Get(context.Context, string) (*Setting, error) {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) GetValue(_ context.Context, key string) (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
value, ok := r.values[key]
|
||||
if !ok {
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) Set(context.Context, string, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) GetMultiple(context.Context, []string) (map[string]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) SetMultiple(context.Context, map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) GetAll(context.Context) (map[string]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *subscriptionExpirySettingRepoStub) Delete(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSubscriptionExpiryService_ExpiryReminderEnabledDefaultsToTrue(t *testing.T) {
|
||||
svc := NewSubscriptionExpiryService(nil, time.Minute)
|
||||
svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{values: map[string]string{}})
|
||||
|
||||
require.True(t, svc.expiryReminderEnabled(context.Background()))
|
||||
}
|
||||
|
||||
func TestSubscriptionExpiryService_ExpiryReminderDisabledSkipsSubscriptionScan(t *testing.T) {
|
||||
repo := &subscriptionExpiryRepoStub{}
|
||||
settingRepo := &subscriptionExpirySettingRepoStub{
|
||||
values: map[string]string{SettingKeySubscriptionExpiryNotifyEnabled: "false"},
|
||||
}
|
||||
svc := NewSubscriptionExpiryService(repo, time.Minute)
|
||||
svc.SetSettingRepository(settingRepo)
|
||||
svc.SetNotificationEmailService(NewNotificationEmailService(settingRepo, nil))
|
||||
|
||||
svc.sendExpiryReminders(context.Background())
|
||||
|
||||
require.Zero(t, repo.listCalls)
|
||||
}
|
||||
|
||||
func TestSubscriptionExpiryService_ExpiryReminderSettingReadErrorFailsClosed(t *testing.T) {
|
||||
svc := NewSubscriptionExpiryService(nil, time.Minute)
|
||||
svc.SetSettingRepository(&subscriptionExpirySettingRepoStub{err: errors.New("db down")})
|
||||
|
||||
require.False(t, svc.expiryReminderEnabled(context.Background()))
|
||||
}
|
||||
@ -27,6 +27,7 @@ type TokenRefreshService struct {
|
||||
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
|
||||
tempUnschedCache TempUnschedCache // 用于清除 Redis 中的临时不可调度缓存
|
||||
refreshAPI *OAuthRefreshAPI // 统一刷新 API
|
||||
runtimeBlocker AccountRuntimeBlocker
|
||||
|
||||
// OpenAI privacy: 刷新成功后检查并设置 training opt-out
|
||||
privacyClientFactory PrivacyClientFactory
|
||||
@ -100,6 +101,24 @@ func (s *TokenRefreshService) SetRefreshPolicy(policy BackgroundRefreshPolicy) {
|
||||
s.refreshPolicy = policy
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) SetAccountRuntimeBlocker(blocker AccountRuntimeBlocker) {
|
||||
s.runtimeBlocker = blocker
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) notifyAccountSchedulingBlocked(account *Account, until time.Time, reason string) {
|
||||
if s == nil || s.runtimeBlocker == nil || account == nil {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.BlockAccountScheduling(account, until, reason)
|
||||
}
|
||||
|
||||
func (s *TokenRefreshService) notifyAccountSchedulingBlockCleared(accountID int64) {
|
||||
if s == nil || s.runtimeBlocker == nil || accountID <= 0 {
|
||||
return
|
||||
}
|
||||
s.runtimeBlocker.ClearAccountSchedulingBlock(accountID)
|
||||
}
|
||||
|
||||
// Start 启动后台刷新服务
|
||||
func (s *TokenRefreshService) Start() {
|
||||
if !s.cfg.Enabled {
|
||||
@ -284,6 +303,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
// 不可重试错误(invalid_grant/invalid_client 等)直接标记 error 状态并返回
|
||||
if isNonRetryableRefreshError(err) {
|
||||
errorMsg := fmt.Sprintf("Token refresh failed (non-retryable): %v", err)
|
||||
s.notifyAccountSchedulingBlocked(account, time.Time{}, "token_refresh_non_retryable")
|
||||
if setErr := s.accountRepo.SetError(ctx, account.ID, errorMsg); setErr != nil {
|
||||
slog.Error("token_refresh.set_error_status_failed",
|
||||
"account_id", account.ID,
|
||||
@ -327,6 +347,7 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
// 设置临时不可调度 10 分钟(不标记 error,保持 status=active 让下个刷新周期能继续尝试)
|
||||
until := time.Now().Add(tokenRefreshTempUnschedDuration)
|
||||
reason := fmt.Sprintf("token refresh retry exhausted: %v", lastErr)
|
||||
s.notifyAccountSchedulingBlocked(account, until, "token_refresh_retry_exhausted")
|
||||
if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); setErr != nil {
|
||||
slog.Warn("token_refresh.set_temp_unschedulable_failed",
|
||||
"account_id", account.ID,
|
||||
@ -355,6 +376,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_missing_project_id_error", "account_id", account.ID)
|
||||
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||
}
|
||||
}
|
||||
// 刷新成功后清除临时不可调度状态(处理 OAuth 401 恢复场景)
|
||||
@ -366,6 +388,7 @@ func (s *TokenRefreshService) postRefreshActions(ctx context.Context, account *A
|
||||
)
|
||||
} else {
|
||||
slog.Info("token_refresh.cleared_temp_unschedulable", "account_id", account.ID)
|
||||
s.notifyAccountSchedulingBlockCleared(account.ID)
|
||||
}
|
||||
// 同步清除 Redis 缓存,避免调度器读到过期的临时不可调度状态
|
||||
if s.tempUnschedCache != nil {
|
||||
@ -417,11 +440,12 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
nonRetryable := []string{
|
||||
"invalid_grant", // refresh_token 已失效
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
"invalid_grant", // refresh_token 已失效
|
||||
"refresh_token_reused", // OpenAI refresh_token 已被使用,必须重新授权
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
"no refresh token available",
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
|
||||
@ -532,6 +532,7 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
|
||||
{name: "network_error", err: errors.New("network timeout"), expected: false},
|
||||
{name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
|
||||
{name: "invalid_client", err: errors.New("invalid_client"), expected: true},
|
||||
{name: "refresh_token_reused", err: errors.New(`OPENAI_OAUTH_TOKEN_REFRESH_FAILED: token refresh failed: status 401, body: {"error":{"code":"refresh_token_reused"}}`), expected: true},
|
||||
{name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
|
||||
{name: "access_denied", err: errors.New("access_denied"), expected: true},
|
||||
{name: "no_refresh_token", err: errors.New("no refresh token available"), expected: true},
|
||||
|
||||
@ -62,6 +62,7 @@ func ProvideTokenRefreshService(
|
||||
privacyClientFactory PrivacyClientFactory,
|
||||
proxyRepo ProxyRepository,
|
||||
refreshAPI *OAuthRefreshAPI,
|
||||
runtimeBlocker AccountRuntimeBlocker,
|
||||
) *TokenRefreshService {
|
||||
svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, schedulerCache, cfg, tempUnschedCache)
|
||||
// 注入 OpenAI privacy opt-out 依赖
|
||||
@ -70,6 +71,7 @@ func ProvideTokenRefreshService(
|
||||
svc.SetRefreshAPI(refreshAPI)
|
||||
// 调用侧显式注入后台刷新策略,避免策略漂移
|
||||
svc.SetRefreshPolicy(DefaultBackgroundRefreshPolicy())
|
||||
svc.SetAccountRuntimeBlocker(runtimeBlocker)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
@ -154,8 +156,9 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
||||
}
|
||||
|
||||
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
|
||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
|
||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, settingRepo SettingRepository, notificationEmailService *NotificationEmailService) *SubscriptionExpiryService {
|
||||
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
|
||||
svc.SetSettingRepository(settingRepo)
|
||||
svc.SetNotificationEmailService(notificationEmailService)
|
||||
svc.Start()
|
||||
return svc
|
||||
@ -185,6 +188,7 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: startup cleanup stale process slots failed: %v", err)
|
||||
}
|
||||
if cfg != nil {
|
||||
svc.SetAccountLoadBatchCacheTTL(time.Duration(cfg.Gateway.Scheduling.LoadBatchCacheTTLMS) * time.Millisecond)
|
||||
svc.StartSlotCleanupWorker(accountRepo, cfg.Gateway.Scheduling.SlotCleanupInterval)
|
||||
}
|
||||
return svc
|
||||
@ -400,6 +404,9 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit
|
||||
svc := NewSettingService(settingRepo, cfg)
|
||||
svc.SetDefaultSubscriptionGroupReader(groupRepo)
|
||||
svc.SetProxyRepository(proxyRepo)
|
||||
if err := svc.LoadAPIKeyACLTrustForwardedIPSetting(context.Background()); err != nil {
|
||||
logger.LegacyPrintf("service.setting", "Warning: load api key acl forwarded ip setting failed: %v", err)
|
||||
}
|
||||
antigravity.SetUserAgentVersionResolver(svc.GetAntigravityUserAgentVersion)
|
||||
return svc
|
||||
}
|
||||
@ -455,6 +462,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewRPMTokenBucketService,
|
||||
NewGatewayService,
|
||||
NewOpenAIGatewayService,
|
||||
wire.Bind(new(AccountRuntimeBlocker), new(*OpenAIGatewayService)),
|
||||
NewOAuthService,
|
||||
NewOpenAIOAuthService,
|
||||
NewGeminiOAuthService,
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
-- 修复:user_provider_default_grants 表的 provider_type check 约束
|
||||
-- 与 auth_identities / auth_identity_channels / pending_auth_sessions 保持一致,
|
||||
-- 否则启用了 auth_source_default_{github,google,dingtalk}_grant_on_first_bind
|
||||
-- 之后,OAuth 首次绑定流程会因约束违反而失败。
|
||||
-- 参见 migrations 135、136 漏改本表。
|
||||
|
||||
ALTER TABLE user_provider_default_grants
|
||||
DROP CONSTRAINT IF EXISTS user_provider_default_grants_provider_type_check;
|
||||
|
||||
ALTER TABLE user_provider_default_grants
|
||||
ADD CONSTRAINT user_provider_default_grants_provider_type_check
|
||||
CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));
|
||||
@ -0,0 +1,4 @@
|
||||
-- 订阅到期提醒邮件开关,默认保持历史行为:开启。
|
||||
INSERT INTO settings (key, value, updated_at)
|
||||
VALUES ('subscription_expiry_notify_enabled', 'true', NOW())
|
||||
ON CONFLICT (key) DO NOTHING;
|
||||
@ -405,6 +405,9 @@ gateway:
|
||||
# Enable batch load calculation for scheduling
|
||||
# 启用调度批量负载计算
|
||||
load_batch_enabled: true
|
||||
# Tiny in-process TTL for batch load reads in milliseconds (0 disables)
|
||||
# 调度批量负载读取的进程内短缓存 TTL(毫秒,0 表示禁用)
|
||||
load_batch_cache_ttl_ms: 200
|
||||
# Slot cleanup interval (duration)
|
||||
# 并发槽位清理周期(时间段)
|
||||
slot_cleanup_interval: 30s
|
||||
|
||||
@ -57,5 +57,10 @@
|
||||
"vite-plugin-checker": "^0.9.1",
|
||||
"vitest": "^2.1.9",
|
||||
"vue-tsc": "^2.2.0"
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"js-cookie": "3.0.7"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
15
frontend/pnpm-lock.yaml
generated
15
frontend/pnpm-lock.yaml
generated
@ -4,6 +4,9 @@ settings:
|
||||
autoInstallPeers: true
|
||||
excludeLinksFromLockfile: false
|
||||
|
||||
overrides:
|
||||
js-cookie: 3.0.7
|
||||
|
||||
importers:
|
||||
|
||||
.:
|
||||
@ -2882,9 +2885,9 @@ packages:
|
||||
engines: {node: '>=14'}
|
||||
hasBin: true
|
||||
|
||||
js-cookie@3.0.5:
|
||||
resolution: {integrity: sha512-cEiJEAEoIbWfCZYKWhVwFuvPX1gETRYPw6LlaTKoxD3s2AkXzkCjnp6h0V77ozyqj0jakteJ4YqDJT830+lVGw==}
|
||||
engines: {node: '>=14'}
|
||||
js-cookie@3.0.7:
|
||||
resolution: {integrity: sha512-z/wZZgDrkNV1eA0ULjM/F9/50Ya8fbzgKneSpoPsXSGd0KnpdtHfOZWK+GcwLk+EZbS4F9RBhU+K2RgzuDaItw==}
|
||||
engines: {node: '>=20'}
|
||||
|
||||
js-tokens@4.0.0:
|
||||
resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==}
|
||||
@ -6365,7 +6368,7 @@ snapshots:
|
||||
'@types/js-cookie': 3.0.6
|
||||
dayjs: 1.11.20
|
||||
intersection-observer: 0.12.2
|
||||
js-cookie: 3.0.5
|
||||
js-cookie: 3.0.7
|
||||
lodash: 4.18.1
|
||||
react: 19.2.3
|
||||
react-dom: 19.2.3(react@19.2.3)
|
||||
@ -7656,10 +7659,10 @@ snapshots:
|
||||
config-chain: 1.1.13
|
||||
editorconfig: 1.0.4
|
||||
glob: 10.5.0
|
||||
js-cookie: 3.0.5
|
||||
js-cookie: 3.0.7
|
||||
nopt: 7.2.1
|
||||
|
||||
js-cookie@3.0.5: {}
|
||||
js-cookie@3.0.7: {}
|
||||
|
||||
js-tokens@4.0.0: {}
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ import { apiClient } from '../client'
|
||||
import type {
|
||||
RedeemCode,
|
||||
GenerateRedeemCodesRequest,
|
||||
BatchUpdateRedeemCodeFields,
|
||||
RedeemCodeType,
|
||||
PaginatedResponse
|
||||
} from '@/types'
|
||||
@ -23,7 +24,7 @@ export async function list(
|
||||
pageSize: number = 20,
|
||||
filters?: {
|
||||
type?: RedeemCodeType
|
||||
status?: 'active' | 'used' | 'expired' | 'unused'
|
||||
status?: 'active' | 'used' | 'expired' | 'unused' | 'disabled'
|
||||
search?: string
|
||||
sort_by?: string
|
||||
sort_order?: 'asc' | 'desc'
|
||||
@ -118,6 +119,26 @@ export async function batchDelete(ids: number[]): Promise<{
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch update selected redeem code fields
|
||||
* @param ids - Array of redeem code IDs
|
||||
* @param fields - Field collection to update
|
||||
* @returns Updated count
|
||||
*/
|
||||
export async function batchUpdate(
|
||||
ids: number[],
|
||||
fields: BatchUpdateRedeemCodeFields
|
||||
): Promise<{
|
||||
updated: number
|
||||
message: string
|
||||
}> {
|
||||
const { data } = await apiClient.post<{
|
||||
updated: number
|
||||
message: string
|
||||
}>('/admin/redeem-codes/batch-update', { ids, fields })
|
||||
return data
|
||||
}
|
||||
|
||||
/**
|
||||
* Expire redeem code
|
||||
* @param id - Redeem code ID
|
||||
@ -158,7 +179,7 @@ export async function getStats(): Promise<{
|
||||
*/
|
||||
export async function exportCodes(filters?: {
|
||||
type?: RedeemCodeType
|
||||
status?: 'used' | 'expired' | 'unused'
|
||||
status?: 'used' | 'expired' | 'unused' | 'disabled'
|
||||
search?: string
|
||||
sort_by?: string
|
||||
sort_order?: 'asc' | 'desc'
|
||||
@ -176,6 +197,7 @@ export const redeemAPI = {
|
||||
generate,
|
||||
delete: deleteCode,
|
||||
batchDelete,
|
||||
batchUpdate,
|
||||
expire,
|
||||
getStats,
|
||||
exportCodes
|
||||
|
||||
@ -2,6 +2,12 @@ import { apiClient } from '../client'
|
||||
|
||||
export type ModerationMode = 'off' | 'observe' | 'pre_block'
|
||||
export type KeywordBlockingMode = 'keyword_only' | 'keyword_and_api' | 'api_only'
|
||||
export type ContentModerationModelFilterType = 'all' | 'include' | 'exclude'
|
||||
|
||||
export interface ContentModerationModelFilter {
|
||||
type: ContentModerationModelFilterType
|
||||
models: string[]
|
||||
}
|
||||
|
||||
export interface ContentModerationConfig {
|
||||
enabled: boolean
|
||||
@ -32,6 +38,7 @@ export interface ContentModerationConfig {
|
||||
pre_hash_check_enabled: boolean
|
||||
blocked_keywords: string[]
|
||||
keyword_blocking_mode: KeywordBlockingMode
|
||||
model_filter: ContentModerationModelFilter
|
||||
}
|
||||
|
||||
export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen'
|
||||
@ -105,6 +112,7 @@ export interface UpdateContentModerationConfig {
|
||||
pre_hash_check_enabled?: boolean
|
||||
blocked_keywords?: string[]
|
||||
keyword_blocking_mode?: KeywordBlockingMode
|
||||
model_filter?: ContentModerationModelFilter
|
||||
}
|
||||
|
||||
export interface ContentModerationRuntimeStatus {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user