diff --git a/Dockerfile b/Dockerfile
index aacc7bfb..91964517 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -22,8 +22,8 @@ RUN sed -i 's#https://dl-cdn.alpinelinux.org/alpine#https://mirrors.aliyun.com/a
WORKDIR /app/frontend
-# Install pnpm. Keep this on v9 so Docker builds do not fail on pnpm v10's
-# interactive build-script approval flow.
+# Install pnpm. Pin to a specific v9 patch to dodge pnpm v10's interactive
+# build-script approval flow and keep Docker builds reproducible.
RUN corepack enable && corepack prepare pnpm@9.15.9 --activate
# Install dependencies first (better caching)
diff --git a/README.md b/README.md
index bdb09d15..add1b4eb 100644
--- a/README.md
+++ b/README.md
@@ -62,13 +62,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
-
-Thanks to Poixe Ai for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive sub2api referral link and receive a bonus of $5 USD on your first top-up.
+
+Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click here to register!
-
-Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click here to register!
+
+Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link , you'll receive an extra 10% bonus credit on your first top-up!
+
+
+
+
+Thanks to APIKEY.FUN for sponsoring this project! APIKEY.FUN is one of the core contributors to the sub2api open-source project, dedicated to providing open, stable, and cost-effective AI API access. The platform supports API relay services for Claude, OpenAI, Gemini, and other popular models, with pricing starting from as low as 7% of the original rate. Register via the exclusive link: APIKEY to enjoy a permanent 5% discount on all recharges.
@@ -86,11 +91,6 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!
-
-
-Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via this link , you'll receive an extra 10% bonus credit on your first top-up!
-
-
Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through BmoPlus - Premium AI Accounts & Top-ups , users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)
@@ -108,6 +108,12 @@ Enterprise-grade high concurrency is also supported, with a dedicated management
Register now via this link to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.
+
+
+Thanks to PPToken.org for sponsoring this project! PPToken.org specializes in GPT model API relay services, supporting Codex, Claude Code, OpenAI-compatible clients, and Gemini CLI integration. Top-ups are 1:1 (¥1 = $1 credit); GPT models start at 0.16x rate multiplier, with overall cost at roughly 2.2% of official pricing and first-token latency around 1 second — ideal for developers seeking low-cost, high-speed access to GPT model capabilities. Technical support: 24/7 real human responses (no bots), @tech in the group chat and get a reply within 10 minutes. Sponsor benefit: the first 200 users who register via the exclusive registration link and enter promo code `SUB2API` can claim free Codex / Claude Code trial credits — no minimum spend, no card required.
+
+
+
## Ecosystem
diff --git a/README_CN.md b/README_CN.md
index e13f86de..67340969 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -61,13 +61,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
-
-感谢 Poixe AI 赞助了本项目!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 此链接 专属链接注册,充值额外赠送 $5 美金
+
+感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击这里 注册!
-
-感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击这里 注册!
+
+感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接 注册,首次充值可额外获得 10% 赠送额度!
+
+
+
+
+感谢 APIKEY.FUN 赞助了本项目!APIKEY.FUN 是 sub2api 开源项目的核心贡献者之一,致力于提供开放、稳定、高性价比的 AI API 接入服务。平台支持 Claude、OpenAI、Gemini 等热门模型的 API 中转服务,价格低至官方原价的 7%。通过专属链接 APIKEY 注册,可享受所有充值永久 95 折优惠。
@@ -85,11 +90,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过此链接 注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!
-
-
-感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过此链接 注册,首次充值可额外获得 10% 赠送额度!
-
-
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充 注册下单的用户,可享GPT 官网订阅一折 的震撼价格!
@@ -107,6 +107,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
现在通过 此链接 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。
+
+
+感谢 PPToken.org 赞助本项目! PPToken.org 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:1,1 元=1 美元额度;GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术,10 分钟内有回复 。赞助商福利:前 200 名用户通过 [专属注册链接] 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。
+
+
## 生态项目
diff --git a/README_JA.md b/README_JA.md
index 73331a07..13d710df 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -61,13 +61,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
-
-Poixe AI のご支援に感謝します!Poixe AI は信頼性の高い LLM API サービスを提供しています。プラットフォームの API エンドポイントを活用して、AI 搭載プロダクトをシームレスに構築できます。また、ベンダーとして AI API リソースをプラットフォームに提供し、収益を得ることも可能です。専用の sub2api 紹介リンクから登録すると、初回チャージ時に $5 USD のボーナスがもらえます。
+
+CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。こちら から登録!
-
-CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。こちら から登録!
+
+AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンク から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!
+
+
+
+
+APIKEY.FUN のご支援に感謝します!APIKEY.FUN は sub2api オープンソースプロジェクトのコアコントリビューターの一つであり、オープンで安定した、コストパフォーマンスに優れた AI API アクセスサービスの提供に取り組んでいます。プラットフォームは Claude、OpenAI、Gemini など人気モデルの API 中継サービスをサポートし、価格は公式料金のわずか 7% から。専用リンク APIKEY から登録すると、すべてのチャージで永久 5% 割引をご利用いただけます。
@@ -85,11 +90,6 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:こちらのリンク から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!
-
-
-AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:こちらのリンク から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!
-
-
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ 経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!
@@ -107,6 +107,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
こちらのリンク から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。
+
+
+PPToken.org のご支援に感謝します!PPToken.org は GPT シリーズモデルの API 中継サービスを専門としており、Codex、Claude Code、OpenAI 互換クライアント、Gemini CLI などのツール接続をサポートしています。チャージは 1:1(1元=1ドル分のクレジット)、GPT モデルは最低 0.16 倍のレート倍率で、総合コストは公式価格の約 2.2% 、最速ファーストトークンは約1秒 — 開発者が低コスト・高速レスポンスで GPT モデル機能にアクセスするのに最適です。テクニカルサポート:24時間365日リアルな人間が対応(ボットではありません)、グループ内で @技術 すれば 10 分以内に返信。スポンサー特典:先着 200 名のユーザーが専用登録リンク から登録し、プロモコード `SUB2API` を入力すると、Codex / Claude Code の無料トライアルクレジットを獲得できます — 最低利用額なし、カード登録不要。
+
+
+
## エコシステム
diff --git a/assets/partners/logos/apikey-fun.png b/assets/partners/logos/apikey-fun.png
new file mode 100644
index 00000000..45687b25
Binary files /dev/null and b/assets/partners/logos/apikey-fun.png differ
diff --git a/assets/partners/logos/pptoken.png b/assets/partners/logos/pptoken.png
new file mode 100644
index 00000000..c199e6e9
Binary files /dev/null and b/assets/partners/logos/pptoken.png differ
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 9e7e837e..74799d81 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.126
+0.1.127
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index caba2fc2..9afc20ec 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
- authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
+ userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
+ userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
+ userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
+ authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@@ -202,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
- settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
+ settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
requestEventBus := service.NewRequestEventBus()
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
@@ -217,9 +220,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
- userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
- userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
- userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
@@ -231,7 +231,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
- channelHandler := admin.NewChannelHandler(channelService, billingService)
+ channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
diff --git a/backend/ent/channelmonitor.go b/backend/ent/channelmonitor.go
index dbb73362..defd06c6 100644
--- a/backend/ent/channelmonitor.go
+++ b/backend/ent/channelmonitor.go
@@ -27,6 +27,8 @@ type ChannelMonitor struct {
Name string `json:"name,omitempty"`
// Provider holds the value of the "provider" field.
Provider channelmonitor.Provider `json:"provider,omitempty"`
+ // OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
+ APIMode string `json:"api_mode,omitempty"`
// Provider base origin, e.g. https://api.openai.com
Endpoint string `json:"endpoint,omitempty"`
// AES-256-GCM encrypted API key
@@ -112,7 +114,7 @@ func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
values[i] = new(sql.NullInt64)
- case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
+ case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldAPIMode, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
values[i] = new(sql.NullString)
case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
values[i] = new(sql.NullTime)
@@ -161,6 +163,12 @@ func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Provider = channelmonitor.Provider(value.String)
}
+ case channelmonitor.FieldAPIMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field api_mode", values[i])
+ } else if value.Valid {
+ _m.APIMode = value.String
+ }
case channelmonitor.FieldEndpoint:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field endpoint", values[i])
@@ -310,6 +318,9 @@ func (_m *ChannelMonitor) String() string {
builder.WriteString("provider=")
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
builder.WriteString(", ")
+ builder.WriteString("api_mode=")
+ builder.WriteString(_m.APIMode)
+ builder.WriteString(", ")
builder.WriteString("endpoint=")
builder.WriteString(_m.Endpoint)
builder.WriteString(", ")
diff --git a/backend/ent/channelmonitor/channelmonitor.go b/backend/ent/channelmonitor/channelmonitor.go
index e5a6bfe7..0723ad0d 100644
--- a/backend/ent/channelmonitor/channelmonitor.go
+++ b/backend/ent/channelmonitor/channelmonitor.go
@@ -23,6 +23,8 @@ const (
FieldName = "name"
// FieldProvider holds the string denoting the provider field in the database.
FieldProvider = "provider"
+ // FieldAPIMode holds the string denoting the api_mode field in the database.
+ FieldAPIMode = "api_mode"
// FieldEndpoint holds the string denoting the endpoint field in the database.
FieldEndpoint = "endpoint"
// FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database.
@@ -87,6 +89,7 @@ var Columns = []string{
FieldUpdatedAt,
FieldName,
FieldProvider,
+ FieldAPIMode,
FieldEndpoint,
FieldAPIKeyEncrypted,
FieldPrimaryModel,
@@ -121,6 +124,10 @@ var (
UpdateDefaultUpdatedAt func() time.Time
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
+ // DefaultAPIMode holds the default value on creation for the "api_mode" field.
+ DefaultAPIMode string
+ // APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
+ APIModeValidator func(string) error
// EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
EndpointValidator func(string) error
// APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
@@ -197,6 +204,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProvider, opts...).ToFunc()
}
+// ByAPIMode orders the results by the api_mode field.
+func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
+}
+
// ByEndpoint orders the results by the endpoint field.
func ByEndpoint(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndpoint, opts...).ToFunc()
diff --git a/backend/ent/channelmonitor/where.go b/backend/ent/channelmonitor/where.go
index 755d83a3..8bd8f627 100644
--- a/backend/ent/channelmonitor/where.go
+++ b/backend/ent/channelmonitor/where.go
@@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
}
+// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
+func APIMode(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
+}
+
// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ.
func Endpoint(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
@@ -285,6 +290,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...))
}
+// APIModeEQ applies the EQ predicate on the "api_mode" field.
+func APIModeEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
+}
+
+// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
+func APIModeNEQ(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIMode, v))
+}
+
+// APIModeIn applies the In predicate on the "api_mode" field.
+func APIModeIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldIn(FieldAPIMode, vs...))
+}
+
+// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
+func APIModeNotIn(vs ...string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIMode, vs...))
+}
+
+// APIModeGT applies the GT predicate on the "api_mode" field.
+func APIModeGT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGT(FieldAPIMode, v))
+}
+
+// APIModeGTE applies the GTE predicate on the "api_mode" field.
+func APIModeGTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIMode, v))
+}
+
+// APIModeLT applies the LT predicate on the "api_mode" field.
+func APIModeLT(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLT(FieldAPIMode, v))
+}
+
+// APIModeLTE applies the LTE predicate on the "api_mode" field.
+func APIModeLTE(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIMode, v))
+}
+
+// APIModeContains applies the Contains predicate on the "api_mode" field.
+func APIModeContains(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContains(FieldAPIMode, v))
+}
+
+// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
+func APIModeHasPrefix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIMode, v))
+}
+
+// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
+func APIModeHasSuffix(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIMode, v))
+}
+
+// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
+func APIModeEqualFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIMode, v))
+}
+
+// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
+func APIModeContainsFold(v string) predicate.ChannelMonitor {
+ return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIMode, v))
+}
+
// EndpointEQ applies the EQ predicate on the "endpoint" field.
func EndpointEQ(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
diff --git a/backend/ent/channelmonitor_create.go b/backend/ent/channelmonitor_create.go
index 2f70c300..2593893f 100644
--- a/backend/ent/channelmonitor_create.go
+++ b/backend/ent/channelmonitor_create.go
@@ -65,6 +65,20 @@ func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelM
return _c
}
+// SetAPIMode sets the "api_mode" field.
+func (_c *ChannelMonitorCreate) SetAPIMode(v string) *ChannelMonitorCreate {
+ _c.mutation.SetAPIMode(v)
+ return _c
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_c *ChannelMonitorCreate) SetNillableAPIMode(v *string) *ChannelMonitorCreate {
+ if v != nil {
+ _c.SetAPIMode(*v)
+ }
+ return _c
+}
+
// SetEndpoint sets the "endpoint" field.
func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate {
_c.mutation.SetEndpoint(v)
@@ -275,6 +289,10 @@ func (_c *ChannelMonitorCreate) defaults() {
v := channelmonitor.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
+ if _, ok := _c.mutation.APIMode(); !ok {
+ v := channelmonitor.DefaultAPIMode
+ _c.mutation.SetAPIMode(v)
+ }
if _, ok := _c.mutation.ExtraModels(); !ok {
v := channelmonitor.DefaultExtraModels
_c.mutation.SetExtraModels(v)
@@ -321,6 +339,14 @@ func (_c *ChannelMonitorCreate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
+ if _, ok := _c.mutation.APIMode(); !ok {
+ return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitor.api_mode"`)}
+ }
+ if v, ok := _c.mutation.APIMode(); ok {
+ if err := channelmonitor.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.Endpoint(); !ok {
return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)}
}
@@ -421,6 +447,10 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
_node.Provider = value
}
+ if value, ok := _c.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
+ _node.APIMode = value
+ }
if value, ok := _c.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
_node.Endpoint = value
@@ -606,6 +636,18 @@ func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert {
return u
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorUpsert) SetAPIMode(v string) *ChannelMonitorUpsert {
+ u.Set(channelmonitor.FieldAPIMode, v)
+ return u
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsert) UpdateAPIMode() *ChannelMonitorUpsert {
+ u.SetExcluded(channelmonitor.FieldAPIMode)
+ return u
+}
+
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert {
u.Set(channelmonitor.FieldEndpoint, v)
@@ -885,6 +927,20 @@ func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne {
})
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorUpsertOne) SetAPIMode(v string) *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIMode(v)
+ })
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertOne) UpdateAPIMode() *ChannelMonitorUpsertOne {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIMode()
+ })
+}
+
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne {
return u.Update(func(s *ChannelMonitorUpsert) {
@@ -1362,6 +1418,20 @@ func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk {
})
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorUpsertBulk) SetAPIMode(v string) *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.SetAPIMode(v)
+ })
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorUpsertBulk) UpdateAPIMode() *ChannelMonitorUpsertBulk {
+ return u.Update(func(s *ChannelMonitorUpsert) {
+ s.UpdateAPIMode()
+ })
+}
+
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk {
return u.Update(func(s *ChannelMonitorUpsert) {
diff --git a/backend/ent/channelmonitor_update.go b/backend/ent/channelmonitor_update.go
index 4bbcd564..2cd5e656 100644
--- a/backend/ent/channelmonitor_update.go
+++ b/backend/ent/channelmonitor_update.go
@@ -66,6 +66,20 @@ func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider)
return _u
}
+// SetAPIMode sets the "api_mode" field.
+func (_u *ChannelMonitorUpdate) SetAPIMode(v string) *ChannelMonitorUpdate {
+ _u.mutation.SetAPIMode(v)
+ return _u
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdate) SetNillableAPIMode(v *string) *ChannelMonitorUpdate {
+ if v != nil {
+ _u.SetAPIMode(*v)
+ }
+ return _u
+}
+
// SetEndpoint sets the "endpoint" field.
func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate {
_u.mutation.SetEndpoint(v)
@@ -418,6 +432,11 @@ func (_u *ChannelMonitorUpdate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
+ if v, ok := _u.mutation.APIMode(); ok {
+ if err := channelmonitor.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Endpoint(); ok {
if err := channelmonitor.EndpointValidator(v); err != nil {
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
@@ -472,6 +491,9 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
}
+ if value, ok := _u.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
+ }
if value, ok := _u.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
}
@@ -701,6 +723,20 @@ func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provide
return _u
}
+// SetAPIMode sets the "api_mode" field.
+func (_u *ChannelMonitorUpdateOne) SetAPIMode(v string) *ChannelMonitorUpdateOne {
+ _u.mutation.SetAPIMode(v)
+ return _u
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_u *ChannelMonitorUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorUpdateOne {
+ if v != nil {
+ _u.SetAPIMode(*v)
+ }
+ return _u
+}
+
// SetEndpoint sets the "endpoint" field.
func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne {
_u.mutation.SetEndpoint(v)
@@ -1066,6 +1102,11 @@ func (_u *ChannelMonitorUpdateOne) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
+ if v, ok := _u.mutation.APIMode(); ok {
+ if err := channelmonitor.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Endpoint(); ok {
if err := channelmonitor.EndpointValidator(v); err != nil {
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
@@ -1137,6 +1178,9 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
}
+ if value, ok := _u.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
+ }
if value, ok := _u.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
}
diff --git a/backend/ent/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate.go
index b8429a4d..7f417efb 100644
--- a/backend/ent/channelmonitorrequesttemplate.go
+++ b/backend/ent/channelmonitorrequesttemplate.go
@@ -26,6 +26,8 @@ type ChannelMonitorRequestTemplate struct {
Name string `json:"name,omitempty"`
// Provider holds the value of the "provider" field.
Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
+ // OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
+ APIMode string `json:"api_mode,omitempty"`
// Description holds the value of the "description" field.
Description string `json:"description,omitempty"`
// ExtraHeaders holds the value of the "extra_headers" field.
@@ -67,7 +69,7 @@ func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error
values[i] = new([]byte)
case channelmonitorrequesttemplate.FieldID:
values[i] = new(sql.NullInt64)
- case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
+ case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldAPIMode, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
values[i] = new(sql.NullString)
case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@@ -116,6 +118,12 @@ func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values [
} else if value.Valid {
_m.Provider = channelmonitorrequesttemplate.Provider(value.String)
}
+ case channelmonitorrequesttemplate.FieldAPIMode:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field api_mode", values[i])
+ } else if value.Valid {
+ _m.APIMode = value.String
+ }
case channelmonitorrequesttemplate.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
@@ -197,6 +205,9 @@ func (_m *ChannelMonitorRequestTemplate) String() string {
builder.WriteString("provider=")
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
builder.WriteString(", ")
+ builder.WriteString("api_mode=")
+ builder.WriteString(_m.APIMode)
+ builder.WriteString(", ")
builder.WriteString("description=")
builder.WriteString(_m.Description)
builder.WriteString(", ")
diff --git a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
index 65b8d641..db04aee1 100644
--- a/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
+++ b/backend/ent/channelmonitorrequesttemplate/channelmonitorrequesttemplate.go
@@ -23,6 +23,8 @@ const (
FieldName = "name"
// FieldProvider holds the string denoting the provider field in the database.
FieldProvider = "provider"
+ // FieldAPIMode holds the string denoting the api_mode field in the database.
+ FieldAPIMode = "api_mode"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// FieldExtraHeaders holds the string denoting the extra_headers field in the database.
@@ -51,6 +53,7 @@ var Columns = []string{
FieldUpdatedAt,
FieldName,
FieldProvider,
+ FieldAPIMode,
FieldDescription,
FieldExtraHeaders,
FieldBodyOverrideMode,
@@ -76,6 +79,10 @@ var (
UpdateDefaultUpdatedAt func() time.Time
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
+ // DefaultAPIMode holds the default value on creation for the "api_mode" field.
+ DefaultAPIMode string
+ // APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
+ APIModeValidator func(string) error
// DefaultDescription holds the default value on creation for the "description" field.
DefaultDescription string
// DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
@@ -140,6 +147,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProvider, opts...).ToFunc()
}
+// ByAPIMode orders the results by the api_mode field.
+func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
+}
+
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()
diff --git a/backend/ent/channelmonitorrequesttemplate/where.go b/backend/ent/channelmonitorrequesttemplate/where.go
index b95e5df0..9f6d7333 100644
--- a/backend/ent/channelmonitorrequesttemplate/where.go
+++ b/backend/ent/channelmonitorrequesttemplate/where.go
@@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
}
+// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
+func APIMode(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
+}
+
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
@@ -245,6 +250,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
}
+// APIModeEQ applies the EQ predicate on the "api_mode" field.
+func APIModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
+}
+
+// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
+func APIModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldAPIMode, v))
+}
+
+// APIModeIn applies the In predicate on the "api_mode" field.
+func APIModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldAPIMode, vs...))
+}
+
+// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
+func APIModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldAPIMode, vs...))
+}
+
+// APIModeGT applies the GT predicate on the "api_mode" field.
+func APIModeGT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldAPIMode, v))
+}
+
+// APIModeGTE applies the GTE predicate on the "api_mode" field.
+func APIModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldAPIMode, v))
+}
+
+// APIModeLT applies the LT predicate on the "api_mode" field.
+func APIModeLT(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldAPIMode, v))
+}
+
+// APIModeLTE applies the LTE predicate on the "api_mode" field.
+func APIModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldAPIMode, v))
+}
+
+// APIModeContains applies the Contains predicate on the "api_mode" field.
+func APIModeContains(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldAPIMode, v))
+}
+
+// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
+func APIModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldAPIMode, v))
+}
+
+// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
+func APIModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldAPIMode, v))
+}
+
+// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
+func APIModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldAPIMode, v))
+}
+
+// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
+func APIModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
+ return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldAPIMode, v))
+}
+
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
diff --git a/backend/ent/channelmonitorrequesttemplate_create.go b/backend/ent/channelmonitorrequesttemplate_create.go
index 1ba842cd..45405270 100644
--- a/backend/ent/channelmonitorrequesttemplate_create.go
+++ b/backend/ent/channelmonitorrequesttemplate_create.go
@@ -63,6 +63,20 @@ func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorreque
return _c
}
+// SetAPIMode sets the "api_mode" field.
+func (_c *ChannelMonitorRequestTemplateCreate) SetAPIMode(v string) *ChannelMonitorRequestTemplateCreate {
+ _c.mutation.SetAPIMode(v)
+ return _c
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_c *ChannelMonitorRequestTemplateCreate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateCreate {
+ if v != nil {
+ _c.SetAPIMode(*v)
+ }
+ return _c
+}
+
// SetDescription sets the "description" field.
func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
_c.mutation.SetDescription(v)
@@ -161,6 +175,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
v := channelmonitorrequesttemplate.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
+ if _, ok := _c.mutation.APIMode(); !ok {
+ v := channelmonitorrequesttemplate.DefaultAPIMode
+ _c.mutation.SetAPIMode(v)
+ }
if _, ok := _c.mutation.Description(); !ok {
v := channelmonitorrequesttemplate.DefaultDescription
_c.mutation.SetDescription(v)
@@ -199,6 +217,14 @@ func (_c *ChannelMonitorRequestTemplateCreate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
+ if _, ok := _c.mutation.APIMode(); !ok {
+ return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.api_mode"`)}
+ }
+ if v, ok := _c.mutation.APIMode(); ok {
+ if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
+ }
+ }
if v, ok := _c.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@@ -258,6 +284,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequ
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
_node.Provider = value
}
+ if value, ok := _c.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
+ _node.APIMode = value
+ }
if value, ok := _c.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
_node.Description = value
@@ -378,6 +408,18 @@ func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRe
return u
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsert) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsert {
+ u.Set(channelmonitorrequesttemplate.FieldAPIMode, v)
+ return u
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsert) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsert {
+ u.SetExcluded(channelmonitorrequesttemplate.FieldAPIMode)
+ return u
+}
+
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
u.Set(channelmonitorrequesttemplate.FieldDescription, v)
@@ -525,6 +567,20 @@ func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonito
})
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetAPIMode(v)
+ })
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertOne {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateAPIMode()
+ })
+}
+
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
@@ -848,6 +904,20 @@ func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonit
})
}
+// SetAPIMode sets the "api_mode" field.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.SetAPIMode(v)
+ })
+}
+
+// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
+func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertBulk {
+ return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
+ s.UpdateAPIMode()
+ })
+}
+
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
diff --git a/backend/ent/channelmonitorrequesttemplate_update.go b/backend/ent/channelmonitorrequesttemplate_update.go
index 8f55ba04..f27cac1c 100644
--- a/backend/ent/channelmonitorrequesttemplate_update.go
+++ b/backend/ent/channelmonitorrequesttemplate_update.go
@@ -63,6 +63,20 @@ func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmon
return _u
}
+// SetAPIMode sets the "api_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdate {
+ _u.mutation.SetAPIMode(v)
+ return _u
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdate {
+ if v != nil {
+ _u.SetAPIMode(*v)
+ }
+ return _u
+}
+
// SetDescription sets the "description" field.
func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
_u.mutation.SetDescription(v)
@@ -204,6 +218,11 @@ func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
+ if v, ok := _u.mutation.APIMode(); ok {
+ if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@@ -238,6 +257,9 @@ func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_no
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
}
+ if value, ok := _u.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
+ }
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
}
@@ -355,6 +377,20 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channel
return _u
}
+// SetAPIMode sets the "api_mode" field.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
+ _u.mutation.SetAPIMode(v)
+ return _u
+}
+
+// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
+func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
+ if v != nil {
+ _u.SetAPIMode(*v)
+ }
+ return _u
+}
+
// SetDescription sets the "description" field.
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
_u.mutation.SetDescription(v)
@@ -509,6 +545,11 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
+ if v, ok := _u.mutation.APIMode(); ok {
+ if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
+ return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
+ }
+ }
if v, ok := _u.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@@ -560,6 +601,9 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
}
+ if value, ok := _u.mutation.APIMode(); ok {
+ _spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
+ }
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
}
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 525ff092..b1731a35 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -428,6 +428,7 @@ var (
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
{Name: "endpoint", Type: field.TypeString, Size: 500},
{Name: "api_key_encrypted", Type: field.TypeString},
{Name: "primary_model", Type: field.TypeString, Size: 200},
@@ -450,7 +451,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
- Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ Columns: []*schema.Column{ChannelMonitorsColumns[18]},
RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -459,22 +460,27 @@ var (
{
Name: "channelmonitor_enabled_last_checked_at",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]},
+ Columns: []*schema.Column{ChannelMonitorsColumns[11], ChannelMonitorsColumns[13]},
},
{
Name: "channelmonitor_provider",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[4]},
},
+ {
+ Name: "channelmonitor_provider_api_mode",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorsColumns[4], ChannelMonitorsColumns[5]},
+ },
{
Name: "channelmonitor_group_name",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorsColumns[9]},
+ Columns: []*schema.Column{ChannelMonitorsColumns[10]},
},
{
Name: "channelmonitor_template_id",
Unique: false,
- Columns: []*schema.Column{ChannelMonitorsColumns[17]},
+ Columns: []*schema.Column{ChannelMonitorsColumns[18]},
},
},
}
@@ -566,6 +572,7 @@ var (
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
+ {Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
{Name: "extra_headers", Type: field.TypeJSON},
{Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
@@ -582,6 +589,11 @@ var (
Unique: true,
Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
},
+ {
+ Name: "channelmonitorrequesttemplate_provider_api_mode",
+ Unique: false,
+ Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[5]},
+ },
},
}
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
@@ -1120,6 +1132,7 @@ var (
{Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "validity_days", Type: field.TypeInt, Default: 30},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "used_by", Type: field.TypeInt64, Nullable: true},
@@ -1132,13 +1145,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "redeem_codes_groups_redeem_codes",
- Columns: []*schema.Column{RedeemCodesColumns[9]},
+ Columns: []*schema.Column{RedeemCodesColumns[10]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "redeem_codes_users_redeem_codes",
- Columns: []*schema.Column{RedeemCodesColumns[10]},
+ Columns: []*schema.Column{RedeemCodesColumns[11]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
@@ -1152,12 +1165,17 @@ var (
{
Name: "redeemcode_used_by",
Unique: false,
- Columns: []*schema.Column{RedeemCodesColumns[10]},
+ Columns: []*schema.Column{RedeemCodesColumns[11]},
},
{
Name: "redeemcode_group_id",
Unique: false,
- Columns: []*schema.Column{RedeemCodesColumns[9]},
+ Columns: []*schema.Column{RedeemCodesColumns[10]},
+ },
+ {
+ Name: "redeemcode_expires_at",
+ Unique: false,
+ Columns: []*schema.Column{RedeemCodesColumns[8]},
},
},
}
@@ -1318,6 +1336,10 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
+ {Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32},
+ {Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32},
+ {Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16},
+ {Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@@ -1334,31 +1356,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[39]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[40]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[41]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -1367,32 +1389,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36]},
+ Columns: []*schema.Column{UsageLogsColumns[40]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33]},
+ Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[34]},
+ Columns: []*schema.Column{UsageLogsColumns[38]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[35]},
+ Columns: []*schema.Column{UsageLogsColumns[39]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[37]},
+ Columns: []*schema.Column{UsageLogsColumns[41]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_model",
@@ -1412,17 +1434,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
+ Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]},
},
},
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 13f6193d..af0edc68 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -8752,6 +8752,7 @@ type ChannelMonitorMutation struct {
updated_at *time.Time
name *string
provider *channelmonitor.Provider
+ api_mode *string
endpoint *string
api_key_encrypted *string
primary_model *string
@@ -9023,6 +9024,42 @@ func (m *ChannelMonitorMutation) ResetProvider() {
m.provider = nil
}
+// SetAPIMode sets the "api_mode" field.
+func (m *ChannelMonitorMutation) SetAPIMode(s string) {
+ m.api_mode = &s
+}
+
+// APIMode returns the value of the "api_mode" field in the mutation.
+func (m *ChannelMonitorMutation) APIMode() (r string, exists bool) {
+ v := m.api_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitor entity.
+// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorMutation) OldAPIMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAPIMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
+ }
+ return oldValue.APIMode, nil
+}
+
+// ResetAPIMode resets all changes to the "api_mode" field.
+func (m *ChannelMonitorMutation) ResetAPIMode() {
+ m.api_mode = nil
+}
+
// SetEndpoint sets the "endpoint" field.
func (m *ChannelMonitorMutation) SetEndpoint(s string) {
m.endpoint = &s
@@ -9780,7 +9817,7 @@ func (m *ChannelMonitorMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorMutation) Fields() []string {
- fields := make([]string, 0, 17)
+ fields := make([]string, 0, 18)
if m.created_at != nil {
fields = append(fields, channelmonitor.FieldCreatedAt)
}
@@ -9793,6 +9830,9 @@ func (m *ChannelMonitorMutation) Fields() []string {
if m.provider != nil {
fields = append(fields, channelmonitor.FieldProvider)
}
+ if m.api_mode != nil {
+ fields = append(fields, channelmonitor.FieldAPIMode)
+ }
if m.endpoint != nil {
fields = append(fields, channelmonitor.FieldEndpoint)
}
@@ -9848,6 +9888,8 @@ func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
return m.Name()
case channelmonitor.FieldProvider:
return m.Provider()
+ case channelmonitor.FieldAPIMode:
+ return m.APIMode()
case channelmonitor.FieldEndpoint:
return m.Endpoint()
case channelmonitor.FieldAPIKeyEncrypted:
@@ -9891,6 +9933,8 @@ func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent
return m.OldName(ctx)
case channelmonitor.FieldProvider:
return m.OldProvider(ctx)
+ case channelmonitor.FieldAPIMode:
+ return m.OldAPIMode(ctx)
case channelmonitor.FieldEndpoint:
return m.OldEndpoint(ctx)
case channelmonitor.FieldAPIKeyEncrypted:
@@ -9954,6 +9998,13 @@ func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
}
m.SetProvider(v)
return nil
+ case channelmonitor.FieldAPIMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAPIMode(v)
+ return nil
case channelmonitor.FieldEndpoint:
v, ok := value.(string)
if !ok {
@@ -10160,6 +10211,9 @@ func (m *ChannelMonitorMutation) ResetField(name string) error {
case channelmonitor.FieldProvider:
m.ResetProvider()
return nil
+ case channelmonitor.FieldAPIMode:
+ m.ResetAPIMode()
+ return nil
case channelmonitor.FieldEndpoint:
m.ResetEndpoint()
return nil
@@ -12591,6 +12645,7 @@ type ChannelMonitorRequestTemplateMutation struct {
updated_at *time.Time
name *string
provider *channelmonitorrequesttemplate.Provider
+ api_mode *string
description *string
extra_headers *map[string]string
body_override_mode *string
@@ -12846,6 +12901,42 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
m.provider = nil
}
+// SetAPIMode sets the "api_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) SetAPIMode(s string) {
+ m.api_mode = &s
+}
+
+// APIMode returns the value of the "api_mode" field in the mutation.
+func (m *ChannelMonitorRequestTemplateMutation) APIMode() (r string, exists bool) {
+ v := m.api_mode
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitorRequestTemplate entity.
+// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *ChannelMonitorRequestTemplateMutation) OldAPIMode(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAPIMode requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
+ }
+ return oldValue.APIMode, nil
+}
+
+// ResetAPIMode resets all changes to the "api_mode" field.
+func (m *ChannelMonitorRequestTemplateMutation) ResetAPIMode() {
+ m.api_mode = nil
+}
+
// SetDescription sets the "description" field.
func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
m.description = &s
@@ -13104,7 +13195,7 @@ func (m *ChannelMonitorRequestTemplateMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
- fields := make([]string, 0, 8)
+ fields := make([]string, 0, 9)
if m.created_at != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
}
@@ -13117,6 +13208,9 @@ func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
if m.provider != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
}
+ if m.api_mode != nil {
+ fields = append(fields, channelmonitorrequesttemplate.FieldAPIMode)
+ }
if m.description != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
}
@@ -13145,6 +13239,8 @@ func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, b
return m.Name()
case channelmonitorrequesttemplate.FieldProvider:
return m.Provider()
+ case channelmonitorrequesttemplate.FieldAPIMode:
+ return m.APIMode()
case channelmonitorrequesttemplate.FieldDescription:
return m.Description()
case channelmonitorrequesttemplate.FieldExtraHeaders:
@@ -13170,6 +13266,8 @@ func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, na
return m.OldName(ctx)
case channelmonitorrequesttemplate.FieldProvider:
return m.OldProvider(ctx)
+ case channelmonitorrequesttemplate.FieldAPIMode:
+ return m.OldAPIMode(ctx)
case channelmonitorrequesttemplate.FieldDescription:
return m.OldDescription(ctx)
case channelmonitorrequesttemplate.FieldExtraHeaders:
@@ -13215,6 +13313,13 @@ func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.
}
m.SetProvider(v)
return nil
+ case channelmonitorrequesttemplate.FieldAPIMode:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAPIMode(v)
+ return nil
case channelmonitorrequesttemplate.FieldDescription:
v, ok := value.(string)
if !ok {
@@ -13319,6 +13424,9 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
case channelmonitorrequesttemplate.FieldProvider:
m.ResetProvider()
return nil
+ case channelmonitorrequesttemplate.FieldAPIMode:
+ m.ResetAPIMode()
+ return nil
case channelmonitorrequesttemplate.FieldDescription:
m.ResetDescription()
return nil
@@ -28602,6 +28710,7 @@ type RedeemCodeMutation struct {
used_at *time.Time
notes *string
created_at *time.Time
+ expires_at *time.Time
validity_days *int
addvalidity_days *int
clearedFields map[string]struct{}
@@ -29059,6 +29168,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() {
m.created_at = nil
}
+// SetExpiresAt sets the "expires_at" field.
+func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) {
+ m.expires_at = &t
+}
+
+// ExpiresAt returns the value of the "expires_at" field in the mutation.
+func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) {
+ v := m.expires_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity.
+// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldExpiresAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
+ }
+ return oldValue.ExpiresAt, nil
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (m *RedeemCodeMutation) ClearExpiresAt() {
+ m.expires_at = nil
+ m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{}
+}
+
+// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
+func (m *RedeemCodeMutation) ExpiresAtCleared() bool {
+ _, ok := m.clearedFields[redeemcode.FieldExpiresAt]
+ return ok
+}
+
+// ResetExpiresAt resets all changes to the "expires_at" field.
+func (m *RedeemCodeMutation) ResetExpiresAt() {
+ m.expires_at = nil
+ delete(m.clearedFields, redeemcode.FieldExpiresAt)
+}
+
// SetGroupID sets the "group_id" field.
func (m *RedeemCodeMutation) SetGroupID(i int64) {
m.group = &i
@@ -29265,7 +29423,7 @@ func (m *RedeemCodeMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *RedeemCodeMutation) Fields() []string {
- fields := make([]string, 0, 10)
+ fields := make([]string, 0, 11)
if m.code != nil {
fields = append(fields, redeemcode.FieldCode)
}
@@ -29290,6 +29448,9 @@ func (m *RedeemCodeMutation) Fields() []string {
if m.created_at != nil {
fields = append(fields, redeemcode.FieldCreatedAt)
}
+ if m.expires_at != nil {
+ fields = append(fields, redeemcode.FieldExpiresAt)
+ }
if m.group != nil {
fields = append(fields, redeemcode.FieldGroupID)
}
@@ -29320,6 +29481,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) {
return m.Notes()
case redeemcode.FieldCreatedAt:
return m.CreatedAt()
+ case redeemcode.FieldExpiresAt:
+ return m.ExpiresAt()
case redeemcode.FieldGroupID:
return m.GroupID()
case redeemcode.FieldValidityDays:
@@ -29349,6 +29512,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val
return m.OldNotes(ctx)
case redeemcode.FieldCreatedAt:
return m.OldCreatedAt(ctx)
+ case redeemcode.FieldExpiresAt:
+ return m.OldExpiresAt(ctx)
case redeemcode.FieldGroupID:
return m.OldGroupID(ctx)
case redeemcode.FieldValidityDays:
@@ -29418,6 +29583,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error {
}
m.SetCreatedAt(v)
return nil
+ case redeemcode.FieldExpiresAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetExpiresAt(v)
+ return nil
case redeemcode.FieldGroupID:
v, ok := value.(int64)
if !ok {
@@ -29498,6 +29670,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string {
if m.FieldCleared(redeemcode.FieldNotes) {
fields = append(fields, redeemcode.FieldNotes)
}
+ if m.FieldCleared(redeemcode.FieldExpiresAt) {
+ fields = append(fields, redeemcode.FieldExpiresAt)
+ }
if m.FieldCleared(redeemcode.FieldGroupID) {
fields = append(fields, redeemcode.FieldGroupID)
}
@@ -29524,6 +29699,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error {
case redeemcode.FieldNotes:
m.ClearNotes()
return nil
+ case redeemcode.FieldExpiresAt:
+ m.ClearExpiresAt()
+ return nil
case redeemcode.FieldGroupID:
m.ClearGroupID()
return nil
@@ -29559,6 +29737,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error {
case redeemcode.FieldCreatedAt:
m.ResetCreatedAt()
return nil
+ case redeemcode.FieldExpiresAt:
+ m.ResetExpiresAt()
+ return nil
case redeemcode.FieldGroupID:
m.ResetGroupID()
return nil
@@ -34260,6 +34441,10 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
+ image_input_size *string
+ image_output_size *string
+ image_size_source *string
+ image_size_breakdown *map[string]int
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
@@ -36202,6 +36387,202 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
+// SetImageInputSize sets the "image_input_size" field.
+func (m *UsageLogMutation) SetImageInputSize(s string) {
+ m.image_input_size = &s
+}
+
+// ImageInputSize returns the value of the "image_input_size" field in the mutation.
+func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) {
+ v := m.image_input_size
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageInputSize requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err)
+ }
+ return oldValue.ImageInputSize, nil
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (m *UsageLogMutation) ClearImageInputSize() {
+ m.image_input_size = nil
+ m.clearedFields[usagelog.FieldImageInputSize] = struct{}{}
+}
+
+// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation.
+func (m *UsageLogMutation) ImageInputSizeCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldImageInputSize]
+ return ok
+}
+
+// ResetImageInputSize resets all changes to the "image_input_size" field.
+func (m *UsageLogMutation) ResetImageInputSize() {
+ m.image_input_size = nil
+ delete(m.clearedFields, usagelog.FieldImageInputSize)
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (m *UsageLogMutation) SetImageOutputSize(s string) {
+ m.image_output_size = &s
+}
+
+// ImageOutputSize returns the value of the "image_output_size" field in the mutation.
+func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) {
+ v := m.image_output_size
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageOutputSize requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err)
+ }
+ return oldValue.ImageOutputSize, nil
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (m *UsageLogMutation) ClearImageOutputSize() {
+ m.image_output_size = nil
+ m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{}
+}
+
+// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation.
+func (m *UsageLogMutation) ImageOutputSizeCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldImageOutputSize]
+ return ok
+}
+
+// ResetImageOutputSize resets all changes to the "image_output_size" field.
+func (m *UsageLogMutation) ResetImageOutputSize() {
+ m.image_output_size = nil
+ delete(m.clearedFields, usagelog.FieldImageOutputSize)
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (m *UsageLogMutation) SetImageSizeSource(s string) {
+ m.image_size_source = &s
+}
+
+// ImageSizeSource returns the value of the "image_size_source" field in the mutation.
+func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) {
+ v := m.image_size_source
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageSizeSource requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err)
+ }
+ return oldValue.ImageSizeSource, nil
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (m *UsageLogMutation) ClearImageSizeSource() {
+ m.image_size_source = nil
+ m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{}
+}
+
+// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation.
+func (m *UsageLogMutation) ImageSizeSourceCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldImageSizeSource]
+ return ok
+}
+
+// ResetImageSizeSource resets all changes to the "image_size_source" field.
+func (m *UsageLogMutation) ResetImageSizeSource() {
+ m.image_size_source = nil
+ delete(m.clearedFields, usagelog.FieldImageSizeSource)
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) {
+ m.image_size_breakdown = &value
+}
+
+// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation.
+func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) {
+ v := m.image_size_breakdown
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err)
+ }
+ return oldValue.ImageSizeBreakdown, nil
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (m *UsageLogMutation) ClearImageSizeBreakdown() {
+ m.image_size_breakdown = nil
+ m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{}
+}
+
+// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation.
+func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown]
+ return ok
+}
+
+// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field.
+func (m *UsageLogMutation) ResetImageSizeBreakdown() {
+ m.image_size_breakdown = nil
+ delete(m.clearedFields, usagelog.FieldImageSizeBreakdown)
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
@@ -36443,7 +36824,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 37)
+ fields := make([]string, 0, 41)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -36549,6 +36930,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
+ if m.image_input_size != nil {
+ fields = append(fields, usagelog.FieldImageInputSize)
+ }
+ if m.image_output_size != nil {
+ fields = append(fields, usagelog.FieldImageOutputSize)
+ }
+ if m.image_size_source != nil {
+ fields = append(fields, usagelog.FieldImageSizeSource)
+ }
+ if m.image_size_breakdown != nil {
+ fields = append(fields, usagelog.FieldImageSizeBreakdown)
+ }
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
@@ -36633,6 +37026,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
+ case usagelog.FieldImageInputSize:
+ return m.ImageInputSize()
+ case usagelog.FieldImageOutputSize:
+ return m.ImageOutputSize()
+ case usagelog.FieldImageSizeSource:
+ return m.ImageSizeSource()
+ case usagelog.FieldImageSizeBreakdown:
+ return m.ImageSizeBreakdown()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
@@ -36716,6 +37117,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
+ case usagelog.FieldImageInputSize:
+ return m.OldImageInputSize(ctx)
+ case usagelog.FieldImageOutputSize:
+ return m.OldImageOutputSize(ctx)
+ case usagelog.FieldImageSizeSource:
+ return m.OldImageSizeSource(ctx)
+ case usagelog.FieldImageSizeBreakdown:
+ return m.OldImageSizeBreakdown(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
@@ -36974,6 +37383,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
+ case usagelog.FieldImageInputSize:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageInputSize(v)
+ return nil
+ case usagelog.FieldImageOutputSize:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageOutputSize(v)
+ return nil
+ case usagelog.FieldImageSizeSource:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageSizeSource(v)
+ return nil
+ case usagelog.FieldImageSizeBreakdown:
+ v, ok := value.(map[string]int)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetImageSizeBreakdown(v)
+ return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
@@ -37291,6 +37728,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
+ if m.FieldCleared(usagelog.FieldImageInputSize) {
+ fields = append(fields, usagelog.FieldImageInputSize)
+ }
+ if m.FieldCleared(usagelog.FieldImageOutputSize) {
+ fields = append(fields, usagelog.FieldImageOutputSize)
+ }
+ if m.FieldCleared(usagelog.FieldImageSizeSource) {
+ fields = append(fields, usagelog.FieldImageSizeSource)
+ }
+ if m.FieldCleared(usagelog.FieldImageSizeBreakdown) {
+ fields = append(fields, usagelog.FieldImageSizeBreakdown)
+ }
return fields
}
@@ -37347,6 +37796,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
+ case usagelog.FieldImageInputSize:
+ m.ClearImageInputSize()
+ return nil
+ case usagelog.FieldImageOutputSize:
+ m.ClearImageOutputSize()
+ return nil
+ case usagelog.FieldImageSizeSource:
+ m.ClearImageSizeSource()
+ return nil
+ case usagelog.FieldImageSizeBreakdown:
+ m.ClearImageSizeBreakdown()
+ return nil
}
return fmt.Errorf("unknown UsageLog nullable field %s", name)
}
@@ -37460,6 +37921,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
+ case usagelog.FieldImageInputSize:
+ m.ResetImageInputSize()
+ return nil
+ case usagelog.FieldImageOutputSize:
+ m.ResetImageOutputSize()
+ return nil
+ case usagelog.FieldImageSizeSource:
+ m.ResetImageSizeSource()
+ return nil
+ case usagelog.FieldImageSizeBreakdown:
+ m.ResetImageSizeBreakdown()
+ return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil
diff --git a/backend/ent/redeemcode.go b/backend/ent/redeemcode.go
index 24cd4231..34b55f6b 100644
--- a/backend/ent/redeemcode.go
+++ b/backend/ent/redeemcode.go
@@ -35,6 +35,8 @@ type RedeemCode struct {
Notes *string `json:"notes,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
+ // ExpiresAt holds the value of the "expires_at" field.
+ ExpiresAt *time.Time `json:"expires_at,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// ValidityDays holds the value of the "validity_days" field.
@@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes:
values[i] = new(sql.NullString)
- case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt:
+ case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.CreatedAt = value.Time
}
+ case redeemcode.FieldExpiresAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field expires_at", values[i])
+ } else if value.Valid {
+ _m.ExpiresAt = new(time.Time)
+ *_m.ExpiresAt = value.Time
+ }
case redeemcode.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string {
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
+ if v := _m.ExpiresAt; v != nil {
+ builder.WriteString("expires_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
diff --git a/backend/ent/redeemcode/redeemcode.go b/backend/ent/redeemcode/redeemcode.go
index b010476c..c7b30c15 100644
--- a/backend/ent/redeemcode/redeemcode.go
+++ b/backend/ent/redeemcode/redeemcode.go
@@ -30,6 +30,8 @@ const (
FieldNotes = "notes"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
+ // FieldExpiresAt holds the string denoting the expires_at field in the database.
+ FieldExpiresAt = "expires_at"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldValidityDays holds the string denoting the validity_days field in the database.
@@ -67,6 +69,7 @@ var Columns = []string{
FieldUsedAt,
FieldNotes,
FieldCreatedAt,
+ FieldExpiresAt,
FieldGroupID,
FieldValidityDays,
}
@@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
+// ByExpiresAt orders the results by the expires_at field.
+func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
+}
+
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
diff --git a/backend/ent/redeemcode/where.go b/backend/ent/redeemcode/where.go
index 1fdedba5..8325b9fc 100644
--- a/backend/ent/redeemcode/where.go
+++ b/backend/ent/redeemcode/where.go
@@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v))
}
+// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
+func ExpiresAt(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
+}
+
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
@@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v))
}
+// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
+func ExpiresAtEQ(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
+func ExpiresAtNEQ(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v))
+}
+
+// ExpiresAtIn applies the In predicate on the "expires_at" field.
+func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
+func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...))
+}
+
+// ExpiresAtGT applies the GT predicate on the "expires_at" field.
+func ExpiresAtGT(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v))
+}
+
+// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
+func ExpiresAtGTE(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtLT applies the LT predicate on the "expires_at" field.
+func ExpiresAtLT(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v))
+}
+
+// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
+func ExpiresAtLTE(v time.Time) predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v))
+}
+
+// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
+func ExpiresAtIsNil() predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt))
+}
+
+// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
+func ExpiresAtNotNil() predicate.RedeemCode {
+ return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt))
+}
+
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
diff --git a/backend/ent/redeemcode_create.go b/backend/ent/redeemcode_create.go
index efdcee40..1bba027b 100644
--- a/backend/ent/redeemcode_create.go
+++ b/backend/ent/redeemcode_create.go
@@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate
return _c
}
+// SetExpiresAt sets the "expires_at" field.
+func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate {
+ _c.mutation.SetExpiresAt(v)
+ return _c
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate {
+ if v != nil {
+ _c.SetExpiresAt(*v)
+ }
+ return _c
+}
+
// SetGroupID sets the "group_id" field.
func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate {
_c.mutation.SetGroupID(v)
@@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) {
_spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
+ if value, ok := _c.mutation.ExpiresAt(); ok {
+ _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
+ _node.ExpiresAt = &value
+ }
if value, ok := _c.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
_node.ValidityDays = value
@@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert {
return u
}
+// SetExpiresAt sets the "expires_at" field.
+func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert {
+ u.Set(redeemcode.FieldExpiresAt, v)
+ return u
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert {
+ u.SetExcluded(redeemcode.FieldExpiresAt)
+ return u
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert {
+ u.SetNull(redeemcode.FieldExpiresAt)
+ return u
+}
+
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert {
u.Set(redeemcode.FieldGroupID, v)
@@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne {
})
}
+// SetExpiresAt sets the "expires_at" field.
+func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.ClearExpiresAt()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
@@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk {
})
}
+// SetExpiresAt sets the "expires_at" field.
+func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.SetExpiresAt(v)
+ })
+}
+
+// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
+func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.UpdateExpiresAt()
+ })
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk {
+ return u.Update(func(s *RedeemCodeUpsert) {
+ s.ClearExpiresAt()
+ })
+}
+
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
diff --git a/backend/ent/redeemcode_update.go b/backend/ent/redeemcode_update.go
index 0f05e06d..1e0ec1e6 100644
--- a/backend/ent/redeemcode_update.go
+++ b/backend/ent/redeemcode_update.go
@@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate {
return _u
}
+// SetExpiresAt sets the "expires_at" field.
+func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate {
+ _u.mutation.ClearExpiresAt()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate {
_u.mutation.SetGroupID(v)
@@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error)
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.ExpiresAtCleared() {
+ _spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
+ }
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}
@@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne {
return _u
}
+// SetExpiresAt sets the "expires_at" field.
+func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne {
+ _u.mutation.SetExpiresAt(v)
+ return _u
+}
+
+// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
+func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne {
+ if v != nil {
+ _u.SetExpiresAt(*v)
+ }
+ return _u
+}
+
+// ClearExpiresAt clears the value of the "expires_at" field.
+func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne {
+ _u.mutation.ClearExpiresAt()
+ return _u
+}
+
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne {
_u.mutation.SetGroupID(v)
@@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode,
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
+ if value, ok := _u.mutation.ExpiresAt(); ok {
+ _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
+ }
+ if _u.mutation.ExpiresAtCleared() {
+ _spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
+ }
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index a282d9ba..6d541e2f 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -464,8 +464,14 @@ func init() {
return nil
}
}()
+ // channelmonitorDescAPIMode is the schema descriptor for api_mode field.
+ channelmonitorDescAPIMode := channelmonitorFields[2].Descriptor()
+ // channelmonitor.DefaultAPIMode holds the default value on creation for the api_mode field.
+ channelmonitor.DefaultAPIMode = channelmonitorDescAPIMode.Default.(string)
+ // channelmonitor.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
+ channelmonitor.APIModeValidator = channelmonitorDescAPIMode.Validators[0].(func(string) error)
// channelmonitorDescEndpoint is the schema descriptor for endpoint field.
- channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor()
+ channelmonitorDescEndpoint := channelmonitorFields[3].Descriptor()
// channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
channelmonitor.EndpointValidator = func() func(string) error {
validators := channelmonitorDescEndpoint.Validators
@@ -483,11 +489,11 @@ func init() {
}
}()
// channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field.
- channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor()
+ channelmonitorDescAPIKeyEncrypted := channelmonitorFields[4].Descriptor()
// channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error)
// channelmonitorDescPrimaryModel is the schema descriptor for primary_model field.
- channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor()
+ channelmonitorDescPrimaryModel := channelmonitorFields[5].Descriptor()
// channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
channelmonitor.PrimaryModelValidator = func() func(string) error {
validators := channelmonitorDescPrimaryModel.Validators
@@ -505,29 +511,29 @@ func init() {
}
}()
// channelmonitorDescExtraModels is the schema descriptor for extra_models field.
- channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor()
+ channelmonitorDescExtraModels := channelmonitorFields[6].Descriptor()
// channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field.
channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string)
// channelmonitorDescGroupName is the schema descriptor for group_name field.
- channelmonitorDescGroupName := channelmonitorFields[6].Descriptor()
+ channelmonitorDescGroupName := channelmonitorFields[7].Descriptor()
// channelmonitor.DefaultGroupName holds the default value on creation for the group_name field.
channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string)
// channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error)
// channelmonitorDescEnabled is the schema descriptor for enabled field.
- channelmonitorDescEnabled := channelmonitorFields[7].Descriptor()
+ channelmonitorDescEnabled := channelmonitorFields[8].Descriptor()
// channelmonitor.DefaultEnabled holds the default value on creation for the enabled field.
channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool)
// channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field.
- channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
+ channelmonitorDescIntervalSeconds := channelmonitorFields[9].Descriptor()
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
// channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
- channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
+ channelmonitorDescExtraHeaders := channelmonitorFields[13].Descriptor()
// channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
// channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
- channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
+ channelmonitorDescBodyOverrideMode := channelmonitorFields[14].Descriptor()
// channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
// channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
@@ -661,18 +667,24 @@ func init() {
return nil
}
}()
+ // channelmonitorrequesttemplateDescAPIMode is the schema descriptor for api_mode field.
+ channelmonitorrequesttemplateDescAPIMode := channelmonitorrequesttemplateFields[2].Descriptor()
+ // channelmonitorrequesttemplate.DefaultAPIMode holds the default value on creation for the api_mode field.
+ channelmonitorrequesttemplate.DefaultAPIMode = channelmonitorrequesttemplateDescAPIMode.Default.(string)
+ // channelmonitorrequesttemplate.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
+ channelmonitorrequesttemplate.APIModeValidator = channelmonitorrequesttemplateDescAPIMode.Validators[0].(func(string) error)
// channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
- channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
+ channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[3].Descriptor()
// channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
// channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
// channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
- channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
+ channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[4].Descriptor()
// channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
// channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
- channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
+ channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[5].Descriptor()
// channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
// channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
@@ -1386,7 +1398,7 @@ func init() {
// redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field.
redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time)
// redeemcodeDescValidityDays is the schema descriptor for validity_days field.
- redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
+ redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
@@ -1722,12 +1734,24 @@ func init() {
usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
+ // usagelogDescImageInputSize is the schema descriptor for image_input_size field.
+ usagelogDescImageInputSize := usagelogFields[35].Descriptor()
+ // usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
+ usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error)
+ // usagelogDescImageOutputSize is the schema descriptor for image_output_size field.
+ usagelogDescImageOutputSize := usagelogFields[36].Descriptor()
+ // usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
+ usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error)
+ // usagelogDescImageSizeSource is the schema descriptor for image_size_source field.
+ usagelogDescImageSizeSource := usagelogFields[37].Descriptor()
+ // usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
+ usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
- usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
+ usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[36].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[40].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go
index 5f864080..3deeadcd 100644
--- a/backend/ent/schema/auth_identity.go
+++ b/backend/ent/schema/auth_identity.go
@@ -15,12 +15,13 @@ import (
)
var authProviderTypes = map[string]struct{}{
- "email": {},
- "github": {},
- "google": {},
- "linuxdo": {},
- "oidc": {},
- "wechat": {},
+ "email": {},
+ "github": {},
+ "google": {},
+ "linuxdo": {},
+ "oidc": {},
+ "wechat": {},
+ "dingtalk": {},
}
func validateAuthProviderType(value string) error {
diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go
index d3e24050..af272790 100644
--- a/backend/ent/schema/auth_identity_schema_test.go
+++ b/backend/ent/schema/auth_identity_schema_test.go
@@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
- for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
+ for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("unknown"))
diff --git a/backend/ent/schema/channel_monitor.go b/backend/ent/schema/channel_monitor.go
index 355ade4b..431ab9c8 100644
--- a/backend/ent/schema/channel_monitor.go
+++ b/backend/ent/schema/channel_monitor.go
@@ -36,6 +36,10 @@ func (ChannelMonitor) Fields() []ent.Field {
MaxLen(100),
field.Enum("provider").
Values("openai", "anthropic", "gemini"),
+ field.String("api_mode").
+ Default("chat_completions").
+ MaxLen(32).
+ Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
field.String("endpoint").
NotEmpty().
MaxLen(500).
@@ -104,6 +108,7 @@ func (ChannelMonitor) Indexes() []ent.Index {
return []ent.Index{
index.Fields("enabled", "last_checked_at"),
index.Fields("provider"),
+ index.Fields("provider", "api_mode"),
index.Fields("group_name"),
index.Fields("template_id"),
}
diff --git a/backend/ent/schema/channel_monitor_request_template.go b/backend/ent/schema/channel_monitor_request_template.go
index 59df2f29..0e0ce3a0 100644
--- a/backend/ent/schema/channel_monitor_request_template.go
+++ b/backend/ent/schema/channel_monitor_request_template.go
@@ -40,6 +40,10 @@ func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
MaxLen(100),
field.Enum("provider").
Values("openai", "anthropic", "gemini"),
+ field.String("api_mode").
+ Default("chat_completions").
+ MaxLen(32).
+ Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
field.String("description").
Optional().
Default("").
@@ -76,5 +80,6 @@ func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
return []ent.Index{
// 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
index.Fields("provider", "name").Unique(),
+ index.Fields("provider", "api_mode"),
}
}
diff --git a/backend/ent/schema/redeem_code.go b/backend/ent/schema/redeem_code.go
index 6fb86148..fdaf0808 100644
--- a/backend/ent/schema/redeem_code.go
+++ b/backend/ent/schema/redeem_code.go
@@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field {
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
+ field.Time("expires_at").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Int64("group_id").
Optional().
Nillable(),
@@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index {
index.Fields("status"),
index.Fields("used_by"),
index.Fields("group_id"),
+ index.Fields("expires_at"),
}
}
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index bd3ebfcc..db9e5178 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
+ field.String("image_input_size").
+ MaxLen(32).
+ Optional().
+ Nillable(),
+ field.String("image_output_size").
+ MaxLen(32).
+ Optional().
+ Nillable(),
+ field.String("image_size_source").
+ MaxLen(16).
+ Optional().
+ Nillable(),
+ field.JSON("image_size_breakdown", map[string]int{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),
diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go
index 08bab83a..c6e04273 100644
--- a/backend/ent/schema/user.go
+++ b/backend/ent/schema/user.go
@@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
field.String("signup_source").
Validate(func(value string) error {
switch value {
- case "email", "linuxdo", "wechat", "oidc", "github", "google":
+ case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
return nil
default:
- return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
+ return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk")
}
}).
Default("email"),
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index a8e0cc6c..283fe828 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -92,6 +93,14 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
+ // ImageInputSize holds the value of the "image_input_size" field.
+ ImageInputSize *string `json:"image_input_size,omitempty"`
+ // ImageOutputSize holds the value of the "image_output_size" field.
+ ImageOutputSize *string `json:"image_output_size,omitempty"`
+ // ImageSizeSource holds the value of the "image_size_source" field.
+ ImageSizeSource *string `json:"image_size_source,omitempty"`
+ // ImageSizeBreakdown holds the value of the "image_size_breakdown" field.
+ ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
+ case usagelog.FieldImageSizeBreakdown:
+ values[i] = new([]byte)
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
- case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
+ case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
+ case usagelog.FieldImageInputSize:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field image_input_size", values[i])
+ } else if value.Valid {
+ _m.ImageInputSize = new(string)
+ *_m.ImageInputSize = value.String
+ }
+ case usagelog.FieldImageOutputSize:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field image_output_size", values[i])
+ } else if value.Valid {
+ _m.ImageOutputSize = new(string)
+ *_m.ImageOutputSize = value.String
+ }
+ case usagelog.FieldImageSizeSource:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field image_size_source", values[i])
+ } else if value.Valid {
+ _m.ImageSizeSource = new(string)
+ *_m.ImageSizeSource = value.String
+ }
+ case usagelog.FieldImageSizeBreakdown:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil {
+ return fmt.Errorf("unmarshal field image_size_breakdown: %w", err)
+ }
+ }
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@@ -640,6 +680,24 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
+ if v := _m.ImageInputSize; v != nil {
+ builder.WriteString("image_input_size=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.ImageOutputSize; v != nil {
+ builder.WriteString("image_output_size=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.ImageSizeSource; v != nil {
+ builder.WriteString("image_size_source=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ builder.WriteString("image_size_breakdown=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown))
+ builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index a7438e60..297e0b41 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -84,6 +84,14 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
+ // FieldImageInputSize holds the string denoting the image_input_size field in the database.
+ FieldImageInputSize = "image_input_size"
+ // FieldImageOutputSize holds the string denoting the image_output_size field in the database.
+ FieldImageOutputSize = "image_output_size"
+ // FieldImageSizeSource holds the string denoting the image_size_source field in the database.
+ FieldImageSizeSource = "image_size_source"
+ // FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database.
+ FieldImageSizeBreakdown = "image_size_breakdown"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@@ -175,6 +183,10 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
+ FieldImageInputSize,
+ FieldImageOutputSize,
+ FieldImageSizeSource,
+ FieldImageSizeBreakdown,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@@ -242,6 +254,12 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
+ // ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
+ ImageInputSizeValidator func(string) error
+ // ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
+ ImageOutputSizeValidator func(string) error
+ // ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
+ ImageSizeSourceValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
+// ByImageInputSize orders the results by the image_input_size field.
+func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldImageInputSize, opts...).ToFunc()
+}
+
+// ByImageOutputSize orders the results by the image_output_size field.
+func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc()
+}
+
+// ByImageSizeSource orders the results by the image_size_source field.
+func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc()
+}
+
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index b8439a03..2987f179 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
+// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ.
+func ImageInputSize(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
+}
+
+// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ.
+func ImageOutputSize(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
+}
+
+// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ.
+func ImageSizeSource(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
+}
+
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
+// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field.
+func ImageInputSizeEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
+}
+
+// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field.
+func ImageInputSizeNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v))
+}
+
+// ImageInputSizeIn applies the In predicate on the "image_input_size" field.
+func ImageInputSizeIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...))
+}
+
+// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field.
+func ImageInputSizeNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...))
+}
+
+// ImageInputSizeGT applies the GT predicate on the "image_input_size" field.
+func ImageInputSizeGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v))
+}
+
+// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field.
+func ImageInputSizeGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v))
+}
+
+// ImageInputSizeLT applies the LT predicate on the "image_input_size" field.
+func ImageInputSizeLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v))
+}
+
+// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field.
+func ImageInputSizeLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v))
+}
+
+// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field.
+func ImageInputSizeContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v))
+}
+
+// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field.
+func ImageInputSizeHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v))
+}
+
+// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field.
+func ImageInputSizeHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v))
+}
+
+// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field.
+func ImageInputSizeIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize))
+}
+
+// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field.
+func ImageInputSizeNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize))
+}
+
+// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field.
+func ImageInputSizeEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v))
+}
+
+// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field.
+func ImageInputSizeContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v))
+}
+
+// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field.
+func ImageOutputSizeEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field.
+func ImageOutputSizeNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeIn applies the In predicate on the "image_output_size" field.
+func ImageOutputSizeIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...))
+}
+
+// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field.
+func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...))
+}
+
+// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field.
+func ImageOutputSizeGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field.
+func ImageOutputSizeGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field.
+func ImageOutputSizeLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field.
+func ImageOutputSizeLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field.
+func ImageOutputSizeContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field.
+func ImageOutputSizeHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field.
+func ImageOutputSizeHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field.
+func ImageOutputSizeIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize))
+}
+
+// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field.
+func ImageOutputSizeNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize))
+}
+
+// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field.
+func ImageOutputSizeEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v))
+}
+
+// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field.
+func ImageOutputSizeContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v))
+}
+
+// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field.
+func ImageSizeSourceEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field.
+func ImageSizeSourceNEQ(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceIn applies the In predicate on the "image_size_source" field.
+func ImageSizeSourceIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...))
+}
+
+// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field.
+func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...))
+}
+
+// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field.
+func ImageSizeSourceGT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field.
+func ImageSizeSourceGTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field.
+func ImageSizeSourceLT(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field.
+func ImageSizeSourceLTE(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field.
+func ImageSizeSourceContains(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field.
+func ImageSizeSourceHasPrefix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field.
+func ImageSizeSourceHasSuffix(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field.
+func ImageSizeSourceIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource))
+}
+
+// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field.
+func ImageSizeSourceNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource))
+}
+
+// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field.
+func ImageSizeSourceEqualFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v))
+}
+
+// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field.
+func ImageSizeSourceContainsFold(v string) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v))
+}
+
+// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field.
+func ImageSizeBreakdownIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown))
+}
+
+// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field.
+func ImageSizeBreakdownNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown))
+}
+
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index fded364e..17e800f9 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
+// SetImageInputSize sets the "image_input_size" field.
+func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate {
+ _c.mutation.SetImageInputSize(v)
+ return _c
+}
+
+// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetImageInputSize(*v)
+ }
+ return _c
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate {
+ _c.mutation.SetImageOutputSize(v)
+ return _c
+}
+
+// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetImageOutputSize(*v)
+ }
+ return _c
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate {
+ _c.mutation.SetImageSizeSource(v)
+ return _c
+}
+
+// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate {
+ if v != nil {
+ _c.SetImageSizeSource(*v)
+ }
+ return _c
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate {
+ _c.mutation.SetImageSizeBreakdown(v)
+ return _c
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
+ if v, ok := _c.mutation.ImageInputSize(); ok {
+ if err := usagelog.ImageInputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.ImageOutputSize(); ok {
+ if err := usagelog.ImageOutputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
+ }
+ }
+ if v, ok := _c.mutation.ImageSizeSource(); ok {
+ if err := usagelog.ImageSizeSourceValidator(v); err != nil {
+ return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
+ }
+ }
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
+ if value, ok := _c.mutation.ImageInputSize(); ok {
+ _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
+ _node.ImageInputSize = &value
+ }
+ if value, ok := _c.mutation.ImageOutputSize(); ok {
+ _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
+ _node.ImageOutputSize = &value
+ }
+ if value, ok := _c.mutation.ImageSizeSource(); ok {
+ _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
+ _node.ImageSizeSource = &value
+ }
+ if value, ok := _c.mutation.ImageSizeBreakdown(); ok {
+ _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
+ _node.ImageSizeBreakdown = value
+ }
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
+// SetImageInputSize sets the "image_input_size" field.
+func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldImageInputSize, v)
+ return u
+}
+
+// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldImageInputSize)
+ return u
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldImageInputSize)
+ return u
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldImageOutputSize, v)
+ return u
+}
+
+// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldImageOutputSize)
+ return u
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldImageOutputSize)
+ return u
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert {
+ u.Set(usagelog.FieldImageSizeSource, v)
+ return u
+}
+
+// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldImageSizeSource)
+ return u
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldImageSizeSource)
+ return u
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert {
+ u.Set(usagelog.FieldImageSizeBreakdown, v)
+ return u
+}
+
+// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldImageSizeBreakdown)
+ return u
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldImageSizeBreakdown)
+ return u
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
+// SetImageInputSize sets the "image_input_size" field.
+func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageInputSize(v)
+ })
+}
+
+// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageInputSize()
+ })
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageInputSize()
+ })
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageOutputSize(v)
+ })
+}
+
+// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageOutputSize()
+ })
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageOutputSize()
+ })
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageSizeSource(v)
+ })
+}
+
+// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageSizeSource()
+ })
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageSizeSource()
+ })
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageSizeBreakdown(v)
+ })
+}
+
+// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageSizeBreakdown()
+ })
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageSizeBreakdown()
+ })
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
+// SetImageInputSize sets the "image_input_size" field.
+func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageInputSize(v)
+ })
+}
+
+// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageInputSize()
+ })
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageInputSize()
+ })
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageOutputSize(v)
+ })
+}
+
+// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageOutputSize()
+ })
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageOutputSize()
+ })
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageSizeSource(v)
+ })
+}
+
+// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageSizeSource()
+ })
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageSizeSource()
+ })
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetImageSizeBreakdown(v)
+ })
+}
+
+// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateImageSizeBreakdown()
+ })
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearImageSizeBreakdown()
+ })
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index bb5ac86c..e8fa003c 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
+// SetImageInputSize sets the "image_input_size" field.
+func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate {
+ _u.mutation.SetImageInputSize(v)
+ return _u
+}
+
+// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetImageInputSize(*v)
+ }
+ return _u
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate {
+ _u.mutation.ClearImageInputSize()
+ return _u
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate {
+ _u.mutation.SetImageOutputSize(v)
+ return _u
+}
+
+// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetImageOutputSize(*v)
+ }
+ return _u
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate {
+ _u.mutation.ClearImageOutputSize()
+ return _u
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate {
+ _u.mutation.SetImageSizeSource(v)
+ return _u
+}
+
+// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate {
+ if v != nil {
+ _u.SetImageSizeSource(*v)
+ }
+ return _u
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate {
+ _u.mutation.ClearImageSizeSource()
+ return _u
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate {
+ _u.mutation.SetImageSizeBreakdown(v)
+ return _u
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate {
+ _u.mutation.ClearImageSizeBreakdown()
+ return _u
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
+ if v, ok := _u.mutation.ImageInputSize(); ok {
+ if err := usagelog.ImageInputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ImageOutputSize(); ok {
+ if err := usagelog.ImageOutputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ImageSizeSource(); ok {
+ if err := usagelog.ImageSizeSourceValidator(v); err != nil {
+ return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
+ }
+ }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
+ if value, ok := _u.mutation.ImageInputSize(); ok {
+ _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
+ }
+ if _u.mutation.ImageInputSizeCleared() {
+ _spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageOutputSize(); ok {
+ _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
+ }
+ if _u.mutation.ImageOutputSizeCleared() {
+ _spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageSizeSource(); ok {
+ _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
+ }
+ if _u.mutation.ImageSizeSourceCleared() {
+ _spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
+ _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
+ }
+ if _u.mutation.ImageSizeBreakdownCleared() {
+ _spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
+ }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
+// SetImageInputSize sets the "image_input_size" field.
+func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne {
+ _u.mutation.SetImageInputSize(v)
+ return _u
+}
+
+// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetImageInputSize(*v)
+ }
+ return _u
+}
+
+// ClearImageInputSize clears the value of the "image_input_size" field.
+func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne {
+ _u.mutation.ClearImageInputSize()
+ return _u
+}
+
+// SetImageOutputSize sets the "image_output_size" field.
+func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne {
+ _u.mutation.SetImageOutputSize(v)
+ return _u
+}
+
+// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetImageOutputSize(*v)
+ }
+ return _u
+}
+
+// ClearImageOutputSize clears the value of the "image_output_size" field.
+func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne {
+ _u.mutation.ClearImageOutputSize()
+ return _u
+}
+
+// SetImageSizeSource sets the "image_size_source" field.
+func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne {
+ _u.mutation.SetImageSizeSource(v)
+ return _u
+}
+
+// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetImageSizeSource(*v)
+ }
+ return _u
+}
+
+// ClearImageSizeSource clears the value of the "image_size_source" field.
+func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne {
+ _u.mutation.ClearImageSizeSource()
+ return _u
+}
+
+// SetImageSizeBreakdown sets the "image_size_breakdown" field.
+func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne {
+ _u.mutation.SetImageSizeBreakdown(v)
+ return _u
+}
+
+// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
+func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne {
+ _u.mutation.ClearImageSizeBreakdown()
+ return _u
+}
+
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
+ if v, ok := _u.mutation.ImageInputSize(); ok {
+ if err := usagelog.ImageInputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ImageOutputSize(); ok {
+ if err := usagelog.ImageOutputSizeValidator(v); err != nil {
+ return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
+ }
+ }
+ if v, ok := _u.mutation.ImageSizeSource(); ok {
+ if err := usagelog.ImageSizeSourceValidator(v); err != nil {
+ return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
+ }
+ }
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
+ if value, ok := _u.mutation.ImageInputSize(); ok {
+ _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
+ }
+ if _u.mutation.ImageInputSizeCleared() {
+ _spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageOutputSize(); ok {
+ _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
+ }
+ if _u.mutation.ImageOutputSizeCleared() {
+ _spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageSizeSource(); ok {
+ _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
+ }
+ if _u.mutation.ImageSizeSourceCleared() {
+ _spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
+ }
+ if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
+ _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
+ }
+ if _u.mutation.ImageSizeBreakdownCleared() {
+ _spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
+ }
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
diff --git a/backend/go.sum b/backend/go.sum
index 130d8eb4..3e8f0f04 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -222,6 +222,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
+github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@@ -255,6 +257,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
+github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
+github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -284,6 +288,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
+github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
+github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@@ -316,6 +322,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
+github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
+github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index d3f3a4ad..3ee806d3 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -79,6 +79,7 @@ type Config struct {
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
+ DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"`
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
Default DefaultConfig `mapstructure:"default"`
@@ -250,6 +251,47 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
+type DingTalkConnectConfig struct {
+ Enabled bool `mapstructure:"enabled"`
+ ClientID string `mapstructure:"client_id"`
+ ClientSecret string `mapstructure:"client_secret"`
+ AuthorizeURL string `mapstructure:"authorize_url"`
+ TokenURL string `mapstructure:"token_url"`
+ UserInfoURL string `mapstructure:"userinfo_url"`
+ Scopes string `mapstructure:"scopes"`
+ RedirectURL string `mapstructure:"redirect_url"`
+ FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
+
+ // 平台底座 + 业务行为
+ DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"(V4 fail-closed)
+ AppType string `mapstructure:"app_type"` // "public" (default) | "internal"
+
+ // Corp 限定(none | internal_only)
+ CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"`
+ InternalCorpID string `mapstructure:"internal_corp_id"`
+ BypassRegistration bool `mapstructure:"bypass_registration"`
+ SyncCorpEmail bool `mapstructure:"sync_corp_email"`
+ SyncDisplayName bool `mapstructure:"sync_display_name"`
+ SyncDept bool `mapstructure:"sync_dept"`
+ SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"`
+ SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"`
+ SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"`
+ SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"`
+ SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"`
+ SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"`
+
+ // 邮箱 + Username
+ RequireEmail bool `mapstructure:"require_email"`
+ UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"`
+
+ // Attribute(私有版扩展点;开源版仅声明)
+ UsernameAttributeKey string `mapstructure:"username_attribute_key"`
+ EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"`
+ EnableAttributeSync bool `mapstructure:"enable_attribute_sync"`
+ AttributeSyncFields []string `mapstructure:"attribute_sync_fields"`
+ AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"`
+}
+
type EmailOAuthProviderConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
@@ -1639,6 +1681,19 @@ func setDefaults() {
viper.SetDefault("oidc_connect.userinfo_id_path", "")
viper.SetDefault("oidc_connect.userinfo_username_path", "")
+ // DingTalk Connect OAuth 登录
+ viper.SetDefault("dingtalk_connect.enabled", false)
+ viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth")
+ viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken")
+ viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me")
+ viper.SetDefault("dingtalk_connect.scopes", "openid")
+ viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback")
+ viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app")
+ viper.SetDefault("dingtalk_connect.app_type", "public")
+ viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none")
+ viper.SetDefault("dingtalk_connect.require_email", true)
+ viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty")
+
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
@@ -2767,6 +2822,9 @@ func (c *Config) Validate() error {
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
+ if err := ValidateDingTalkConfig(c.DingTalk); err != nil {
+ return fmt.Errorf("dingtalk_connect: %w", err)
+ }
return nil
}
diff --git a/backend/internal/config/validate_dingtalk.go b/backend/internal/config/validate_dingtalk.go
new file mode 100644
index 00000000..15734eb5
--- /dev/null
+++ b/backend/internal/config/validate_dingtalk.go
@@ -0,0 +1,30 @@
+// Package config 包含钉钉连接配置的校验逻辑。
+//
+// internal_only 模式安全模型(方案 A):
+// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。
+// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth,
+// 因此 ValidateDingTalkConfig 只要求 app_type=internal(V1),不再要求 InternalCorpID 非空(原 V3 已删除)。
+// InternalCorpID 字段保留,admin 可选填;若填写,checkDingTalkCorpAllowed 不会使用它做约束。
+package config
+
+import "errors"
+
+var (
+ ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal")
+ ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app")
+)
+
+func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error {
+ if !cfg.Enabled {
+ return nil
+ }
+ if cfg.DingTalkAppKind != "internal_app" {
+ return ErrDingTalkV4InvalidAppKind
+ }
+ if cfg.CorpRestrictionPolicy == "internal_only" {
+ if cfg.AppType != "internal" {
+ return ErrDingTalkV1AppTypeMismatch
+ }
+ }
+ return nil
+}
diff --git a/backend/internal/config/validate_dingtalk_test.go b/backend/internal/config/validate_dingtalk_test.go
new file mode 100644
index 00000000..f121b97d
--- /dev/null
+++ b/backend/internal/config/validate_dingtalk_test.go
@@ -0,0 +1,53 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) {
+ require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false}))
+}
+
+func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) {
+ err := ValidateDingTalkConfig(DingTalkConnectConfig{
+ Enabled: true,
+ DingTalkAppKind: "third_party_enterprise_app",
+ CorpRestrictionPolicy: "none",
+ })
+ require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind)
+}
+
+func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) {
+ err := ValidateDingTalkConfig(DingTalkConnectConfig{
+ Enabled: true,
+ DingTalkAppKind: "internal_app",
+ AppType: "public",
+ CorpRestrictionPolicy: "internal_only",
+ InternalCorpID: "dingABC",
+ })
+ require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch)
+}
+
+// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
+// internal_only 策略下,InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。
+func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
+ err := ValidateDingTalkConfig(DingTalkConnectConfig{
+ Enabled: true,
+ DingTalkAppKind: "internal_app",
+ AppType: "internal",
+ CorpRestrictionPolicy: "internal_only",
+ InternalCorpID: "",
+ })
+ require.NoError(t, err)
+}
+
+func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) {
+ require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{
+ Enabled: true,
+ DingTalkAppKind: "internal_app",
+ AppType: "public",
+ CorpRestrictionPolicy: "none",
+ }))
+}
diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go
index 533c899d..8bfe1af9 100644
--- a/backend/internal/handler/admin/account_data.go
+++ b/backend/internal/handler/admin/account_data.go
@@ -44,6 +44,9 @@ type DataProxy struct {
Status string `json:"status"`
}
+// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径,
+// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本,
+// 应新增独立结构而非修改这里。
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index eade363e..267206e6 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -1656,7 +1656,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
}
// GetUsage handles getting account usage information
-// GET /api/v1/admin/accounts/:id/usage?source=passive|active
+// GET /api/v1/admin/accounts/:id/usage?source=passive|active&force=true
func (h *AccountHandler) GetUsage(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@@ -1665,12 +1665,13 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
}
source := c.DefaultQuery("source", "active")
+ force := c.Query("force") == "true"
var usage *service.UsageInfo
if source == "passive" {
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
} else {
- usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
+ usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID, force)
}
if err != nil {
response.ErrorFrom(c, err)
@@ -2022,6 +2023,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
+// SyncUpstreamModels handles syncing live supported models from an account's upstream.
+// POST /api/v1/admin/accounts/:id/models/sync-upstream
+func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
+ accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid account ID")
+ return
+ }
+
+ account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
+ if err != nil {
+ response.NotFound(c, "Account not found")
+ return
+ }
+
+ if h.accountTestService == nil {
+ response.InternalError(c, "Account test service is not configured")
+ return
+ }
+
+ models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
+ if err != nil {
+ var syncErr *service.UpstreamModelSyncError
+ if errors.As(err, &syncErr) {
+ switch syncErr.Kind {
+ case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
+ response.BadRequest(c, syncErr.SafeMessage())
+ default:
+ slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
+ response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
+ }
+ return
+ }
+
+ slog.Warn("sync_upstream_models_failed", "account_id", accountID)
+ response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
+ return
+ }
+
+ response.Success(c, gin.H{"models": models})
+}
+
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
diff --git a/backend/internal/handler/admin/account_handler_available_models_test.go b/backend/internal/handler/admin/account_handler_available_models_test.go
index c5f1e2d8..7b27264b 100644
--- a/backend/internal/handler/admin/account_handler_available_models_test.go
+++ b/backend/internal/handler/admin/account_handler_available_models_test.go
@@ -3,10 +3,14 @@ package admin
import (
"context"
"encoding/json"
+ "io"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@@ -33,6 +37,40 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
return router
}
+type syncUpstreamHTTPUpstream struct {
+ resp *http.Response
+ err error
+}
+
+func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
+ if u.err != nil {
+ return nil, u.err
+ }
+ return u.resp, nil
+}
+
+func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
+ return u.Do(req, proxyURL, accountID, accountConcurrency)
+}
+
+func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ accountTestSvc := service.NewAccountTestService(
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ upstream,
+ &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ nil,
+ )
+ handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
+ router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
+ return router
+}
+
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
@@ -103,3 +141,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
+
+func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
+ svc := &availableModelsAdminService{
+ stubAdminService: newStubAdminService(),
+ account: service.Account{
+ ID: 44,
+ Name: "openai-apikey-missing-key",
+ Platform: service.PlatformOpenAI,
+ Type: service.AccountTypeAPIKey,
+ Status: service.StatusActive,
+ Credentials: map[string]any{
+ "base_url": "https://openai.example.com/v1",
+ },
+ },
+ }
+ router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
+}
+
+func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
+ svc := &availableModelsAdminService{
+ stubAdminService: newStubAdminService(),
+ account: service.Account{
+ ID: 45,
+ Name: "openai-apikey-upstream-error",
+ Platform: service.PlatformOpenAI,
+ Type: service.AccountTypeAPIKey,
+ Status: service.StatusActive,
+ Credentials: map[string]any{
+ "api_key": "openai-key",
+ "base_url": "https://openai.example.com/v1",
+ },
+ },
+ }
+ upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
+ StatusCode: http.StatusBadGateway,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
+ }}
+ router := setupSyncUpstreamModelsRouter(svc, upstream)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadGateway, rec.Code)
+ require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
+ require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
+}
diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go
index 950e6e72..bf547346 100644
--- a/backend/internal/handler/admin/channel_handler.go
+++ b/backend/internal/handler/admin/channel_handler.go
@@ -17,11 +17,12 @@ import (
type ChannelHandler struct {
channelService *service.ChannelService
billingService *service.BillingService
+ pricingService *service.PricingService
}
// NewChannelHandler creates a new admin channel handler
-func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
- return &ChannelHandler{channelService: channelService, billingService: billingService}
+func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler {
+ return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService}
}
// --- Request / Response types ---
@@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
"image_output_price": pricing.ImageOutputPricePerToken,
})
}
+
+// platformToLiteLLMProvider maps a channel platform name to the corresponding
+// LiteLLM provider string used as the key in the pricing catalog.
+var platformToLiteLLMProvider = map[string]string{
+ service.PlatformAnthropic: "anthropic",
+ service.PlatformOpenAI: "openai",
+ service.PlatformGemini: "google",
+ service.PlatformAntigravity: "anthropic",
+}
+
+// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表
+// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic
+func (h *ChannelHandler) SyncPricingModels(c *gin.Context) {
+ platform := strings.ToLower(strings.TrimSpace(c.Query("platform")))
+ if platform == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required").
+ WithMetadata(map[string]string{"param": "platform"}))
+ return
+ }
+
+ provider, ok := platformToLiteLLMProvider[platform]
+ if !ok {
+ response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM",
+ fmt.Sprintf("unsupported platform: %s", platform)).
+ WithMetadata(map[string]string{"param": "platform"}))
+ return
+ }
+
+ models := h.pricingService.ListModelNamesByProvider(provider)
+ response.Success(c, gin.H{"models": models})
+}
diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go
index 12cd4bdd..d05a1a6a 100644
--- a/backend/internal/handler/admin/channel_handler_test.go
+++ b/backend/internal/handler/admin/channel_handler_test.go
@@ -3,10 +3,14 @@
package admin
import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
+
+// ---------------------------------------------------------------------------
+// 3. SyncPricingModels handler
+// ---------------------------------------------------------------------------
+
+func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ h := &ChannelHandler{pricingService: pricingSvc}
+ router.GET("/channels/pricing/sync-models", h.SyncPricingModels)
+ return router
+}
+
+func TestSyncPricingModels_MissingPlatform(t *testing.T) {
+ svc := service.NewPricingService(nil, nil)
+ router := setupSyncPricingModelsRouter(svc)
+
+ req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil)
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+}
+
+func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) {
+ svc := service.NewPricingService(nil, nil)
+ router := setupSyncPricingModelsRouter(svc)
+
+ req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil)
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+}
+
+func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) {
+ svc := service.NewPricingService(nil, nil)
+ router := setupSyncPricingModelsRouter(svc)
+
+ for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} {
+ req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil)
+ w := httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+
+ require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform)
+
+ var body struct {
+ Data struct {
+ Models []string `json:"models"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
+ require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform)
+ }
+}
diff --git a/backend/internal/handler/admin/channel_monitor_handler.go b/backend/internal/handler/admin/channel_monitor_handler.go
index e92c81fe..5560a513 100644
--- a/backend/internal/handler/admin/channel_monitor_handler.go
+++ b/backend/internal/handler/admin/channel_monitor_handler.go
@@ -38,6 +38,7 @@ func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *Ch
type channelMonitorCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Endpoint string `json:"endpoint" binding:"required,max=500"`
APIKey string `json:"api_key" binding:"required,max=2000"`
PrimaryModel string `json:"primary_model" binding:"required,max=200"`
@@ -54,6 +55,7 @@ type channelMonitorCreateRequest struct {
type channelMonitorUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
+ APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
@@ -72,6 +74,7 @@ type channelMonitorResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
+ APIMode string `json:"api_mode"`
Endpoint string `json:"endpoint"`
APIKeyMasked string `json:"api_key_masked"`
APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
@@ -138,6 +141,7 @@ func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse
ID: m.ID,
Name: m.Name,
Provider: m.Provider,
+ APIMode: m.APIMode,
Endpoint: m.Endpoint,
APIKeyMasked: maskAPIKey(m.APIKey),
APIKeyDecryptFailed: m.APIKeyDecryptFailed,
@@ -303,6 +307,7 @@ func (h *ChannelMonitorHandler) Create(c *gin.Context) {
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
Name: req.Name,
Provider: req.Provider,
+ APIMode: req.APIMode,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
@@ -338,6 +343,7 @@ func (h *ChannelMonitorHandler) Update(c *gin.Context) {
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
Name: req.Name,
Provider: req.Provider,
+ APIMode: req.APIMode,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
diff --git a/backend/internal/handler/admin/channel_monitor_template_handler.go b/backend/internal/handler/admin/channel_monitor_template_handler.go
index bebe0929..c842f465 100644
--- a/backend/internal/handler/admin/channel_monitor_template_handler.go
+++ b/backend/internal/handler/admin/channel_monitor_template_handler.go
@@ -27,6 +27,7 @@ func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMon
type channelMonitorTemplateCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
+ APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Description string `json:"description" binding:"max=500"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
@@ -35,6 +36,7 @@ type channelMonitorTemplateCreateRequest struct {
type channelMonitorTemplateUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
+ APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Description *string `json:"description" binding:"omitempty,max=500"`
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
@@ -45,6 +47,7 @@ type channelMonitorTemplateResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
+ APIMode string `json:"api_mode"`
Description string `json:"description"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
@@ -67,6 +70,7 @@ func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *ser
ID: t.ID,
Name: t.Name,
Provider: t.Provider,
+ APIMode: t.APIMode,
Description: t.Description,
ExtraHeaders: headers,
BodyOverrideMode: t.BodyOverrideMode,
@@ -93,6 +97,7 @@ func parseTemplateID(c *gin.Context) (int64, bool) {
func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
Provider: strings.TrimSpace(c.Query("provider")),
+ APIMode: strings.TrimSpace(c.Query("api_mode")),
})
if err != nil {
response.ErrorFrom(c, err)
@@ -129,6 +134,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
Name: req.Name,
Provider: req.Provider,
+ APIMode: req.APIMode,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
@@ -154,6 +160,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
}
t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
Name: req.Name,
+ APIMode: req.APIMode,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
@@ -209,6 +216,7 @@ type associatedMonitorBriefResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
+ APIMode string `json:"api_mode"`
Enabled bool `json:"enabled"`
}
@@ -227,7 +235,7 @@ func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context
out := make([]associatedMonitorBriefResponse, 0, len(items))
for _, m := range items {
out = append(out, associatedMonitorBriefResponse{
- ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
+ ID: m.ID, Name: m.Name, Provider: m.Provider, APIMode: m.APIMode, Enabled: m.Enabled,
})
}
response.Success(c, gin.H{"items": out})
diff --git a/backend/internal/handler/admin/content_moderation_handler.go b/backend/internal/handler/admin/content_moderation_handler.go
index 4266f5d8..6f0f2aab 100644
--- a/backend/internal/handler/admin/content_moderation_handler.go
+++ b/backend/internal/handler/admin/content_moderation_handler.go
@@ -46,6 +46,8 @@ type contentModerationConfigRequest struct {
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
+ BlockedKeywords *[]string `json:"blocked_keywords"`
+ KeywordBlockingMode *string `json:"keyword_blocking_mode"`
}
type contentModerationAPIKeyTestRequest struct {
@@ -103,6 +105,8 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
HitRetentionDays: req.HitRetentionDays,
NonHitRetentionDays: req.NonHitRetentionDays,
PreHashCheckEnabled: req.PreHashCheckEnabled,
+ BlockedKeywords: req.BlockedKeywords,
+ KeywordBlockingMode: req.KeywordBlockingMode,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 460f6357..e9fbb630 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
+ // cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。
keyRaw, _ := json.Marshal(struct {
+ V int `json:"v"`
+ Day string `json:"day"`
UserIDs []int64 `json:"user_ids"`
}{
+ V: 2, // bump 当响应结构变化(如加入 by_platform 时)
+ Day: timezone.Today().Format("2006-01-02"),
UserIDs: userIDs,
})
cacheKey := string(keyRaw)
diff --git a/backend/internal/handler/admin/ops_handler.go b/backend/internal/handler/admin/ops_handler.go
index 0eaac506..7e05fcbd 100644
--- a/backend/internal/handler/admin/ops_handler.go
+++ b/backend/internal/handler/admin/ops_handler.go
@@ -1,9 +1,7 @@
package admin
import (
- "errors"
"fmt"
- "io"
"net/http"
"strconv"
"strings"
@@ -386,79 +384,6 @@ func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
-// RetryRequestErrorClient retries the client request based on stored request body.
-// POST /api/v1/admin/ops/request-errors/:id/retry-client
-func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
- if h.opsService == nil {
- response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
- return
- }
- if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- subject, ok := middleware.GetAuthSubjectFromContext(c)
- if !ok || subject.UserID <= 0 {
- response.Error(c, http.StatusUnauthorized, "Unauthorized")
- return
- }
-
- idStr := strings.TrimSpace(c.Param("id"))
- id, err := strconv.ParseInt(idStr, 10, 64)
- if err != nil || id <= 0 {
- response.BadRequest(c, "Invalid error id")
- return
- }
-
- result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, result)
-}
-
-// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
-// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
-func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
- if h.opsService == nil {
- response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
- return
- }
- if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- subject, ok := middleware.GetAuthSubjectFromContext(c)
- if !ok || subject.UserID <= 0 {
- response.Error(c, http.StatusUnauthorized, "Unauthorized")
- return
- }
-
- idStr := strings.TrimSpace(c.Param("id"))
- id, err := strconv.ParseInt(idStr, 10, 64)
- if err != nil || id <= 0 {
- response.BadRequest(c, "Invalid error id")
- return
- }
-
- idxStr := strings.TrimSpace(c.Param("idx"))
- idx, err := strconv.Atoi(idxStr)
- if err != nil || idx < 0 {
- response.BadRequest(c, "Invalid upstream idx")
- return
- }
-
- result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, result)
-}
-
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
@@ -566,39 +491,6 @@ func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
-// RetryUpstreamError retries upstream error using the original account_id.
-// POST /api/v1/admin/ops/upstream-errors/:id/retry
-func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
- if h.opsService == nil {
- response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
- return
- }
- if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- subject, ok := middleware.GetAuthSubjectFromContext(c)
- if !ok || subject.UserID <= 0 {
- response.Error(c, http.StatusUnauthorized, "Unauthorized")
- return
- }
-
- idStr := strings.TrimSpace(c.Param("id"))
- id, err := strconv.ParseInt(idStr, 10, 64)
- if err != nil || id <= 0 {
- response.BadRequest(c, "Invalid error id")
- return
- }
-
- result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, result)
-}
-
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
@@ -708,106 +600,10 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
}
-type opsRetryRequest struct {
- Mode string `json:"mode"`
- PinnedAccountID *int64 `json:"pinned_account_id"`
- Force bool `json:"force"`
-}
-
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
-// RetryErrorRequest retries a failed request using stored request_body.
-// POST /api/v1/admin/ops/errors/:id/retry
-func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
- if h.opsService == nil {
- response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
- return
- }
- if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- subject, ok := middleware.GetAuthSubjectFromContext(c)
- if !ok || subject.UserID <= 0 {
- response.Error(c, http.StatusUnauthorized, "Unauthorized")
- return
- }
-
- idStr := strings.TrimSpace(c.Param("id"))
- id, err := strconv.ParseInt(idStr, 10, 64)
- if err != nil || id <= 0 {
- response.BadRequest(c, "Invalid error id")
- return
- }
-
- req := opsRetryRequest{Mode: service.OpsRetryModeClient}
- if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
- response.BadRequest(c, "Invalid request: "+err.Error())
- return
- }
- if strings.TrimSpace(req.Mode) == "" {
- req.Mode = service.OpsRetryModeClient
- }
-
- // Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
- _ = req.Force
-
- // Legacy endpoint safety: only allow retrying the client request here.
- // Upstream retries must go through the split endpoints.
- if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
- response.BadRequest(c, "upstream retry is not supported on this endpoint")
- return
- }
-
- result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- response.Success(c, result)
-}
-
-// ListRetryAttempts lists retry attempts for an error log.
-// GET /api/v1/admin/ops/errors/:id/retries
-func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
- if h.opsService == nil {
- response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
- return
- }
- if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
- response.ErrorFrom(c, err)
- return
- }
-
- idStr := strings.TrimSpace(c.Param("id"))
- id, err := strconv.ParseInt(idStr, 10, 64)
- if err != nil || id <= 0 {
- response.BadRequest(c, "Invalid error id")
- return
- }
-
- limit := 50
- if v := strings.TrimSpace(c.Query("limit")); v != "" {
- n, err := strconv.Atoi(v)
- if err != nil || n <= 0 {
- response.BadRequest(c, "Invalid limit")
- return
- }
- limit = n
- }
-
- items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, items)
-}
-
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
@@ -839,7 +635,7 @@ func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
return
}
uid := subject.UserID
- if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
+ if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid); err != nil {
response.ErrorFrom(c, err)
return
}
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index 24365f3d..7b4300b1 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -8,6 +8,7 @@ import (
"fmt"
"strconv"
"strings"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
- Count int `json:"count" binding:"required,min=1,max=100"`
- Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
- Value float64 `json:"value"`
- GroupID *int64 `json:"group_id"` // 订阅类型必填
- ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
+ Count int `json:"count" binding:"required,min=1,max=100"`
+ Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
+ Value float64 `json:"value"`
+ GroupID *int64 `json:"group_id"` // 订阅类型必填
+ ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
+ ExpiresAt *time.Time `json:"expires_at"`
+ ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
type CreateAndRedeemCodeRequest struct {
- Code string `json:"code" binding:"required,min=3,max=128"`
- Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
- Value float64 `json:"value" binding:"required"`
- UserID int64 `json:"user_id" binding:"required,gt=0"`
- GroupID *int64 `json:"group_id"` // subscription 类型必填
- ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
- Notes string `json:"notes"`
+ Code string `json:"code" binding:"required,min=3,max=128"`
+ Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
+ Value float64 `json:"value" binding:"required"`
+ UserID int64 `json:"user_id" binding:"required,gt=0"`
+ GroupID *int64 `json:"group_id"` // subscription 类型必填
+ ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
+ Notes string `json:"notes"`
+ ExpiresAt *time.Time `json:"expires_at"`
+ ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
+}
+
+func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) {
+ if expiresAt != nil && expiresInDays != nil {
+ return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set")
+ }
+
+ now := time.Now().UTC()
+ if expiresInDays != nil {
+ if *expiresInDays <= 0 {
+ return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero")
+ }
+ expires := now.AddDate(0, 0, *expiresInDays)
+ return &expires, nil
+ }
+ if expiresAt == nil {
+ return nil, nil
+ }
+
+ expires := expiresAt.UTC()
+ if !expires.After(now) {
+ return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
+ }
+ return &expires, nil
}
// List handles listing all redeem codes with pagination
@@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
+ expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
@@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
+ ExpiresAt: expiresAt,
})
if execErr != nil {
return nil, execErr
@@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
}
+ expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
@@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
+ ExpiresAt: expiresAt,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
@@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis
}
// If previous run created the code but crashed before redeem, redeem it now.
+ if existing.IsExpired() {
+ return nil, service.ErrRedeemCodeExpired
+ }
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
@@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
- if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
+ if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
@@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
+ expiresAt := ""
+ if code.ExpiresAt != nil {
+ expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05")
+ }
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
@@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy,
usedByEmail,
usedAt,
+ expiresAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go
index f1f7778f..d6972460 100644
--- a/backend/internal/handler/admin/redeem_handler_test.go
+++ b/backend/internal/handler/admin/redeem_handler_test.go
@@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}
+
+func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) {
+ days := 3
+ expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
+ require.NoError(t, err)
+ require.NotNil(t, expiresAt)
+ require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second)
+}
+
+func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) {
+ past := time.Now().UTC().Add(-time.Minute)
+ expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil)
+ require.Error(t, err)
+ require.Nil(t, expiresAt)
+}
+
+func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) {
+ days := 0
+ expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
+ require.Error(t, err)
+ require.Nil(t, expiresAt)
+}
+
+func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) {
+ future := time.Now().UTC().Add(time.Hour)
+ days := 3
+ expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days)
+ require.Error(t, err)
+ require.Nil(t, expiresAt)
+}
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index 0ea664d8..9907d441 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1,9 +1,11 @@
package admin
import (
+ "context"
"crypto/rand"
"encoding/hex"
"encoding/json"
+ "errors"
"fmt"
"log/slog"
"net/http"
@@ -60,10 +62,11 @@ type SettingHandler struct {
opsService *service.OpsService
paymentConfigService *service.PaymentConfigService
paymentService *service.PaymentService
+ userAttributeService *service.UserAttributeService
}
// NewSettingHandler 创建系统设置处理器
-func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
+func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
@@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
opsService: opsService,
paymentConfigService: paymentConfigService,
paymentService: paymentService,
+ userAttributeService: userAttributeService,
}
}
@@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
+ DingTalkConnectEnabled: settings.DingTalkConnectEnabled,
+ DingTalkConnectClientID: settings.DingTalkConnectClientID,
+ DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured,
+ DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL,
+ DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy,
+ DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID,
+ DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration,
+ DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail,
+ DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName,
+ DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept,
+ DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey,
+ DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey,
+ DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey,
+ DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName,
+ DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName,
+ DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: settings.WeChatConnectEnabled,
WeChatConnectAppID: settings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
@@ -258,6 +278,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
+ PaymentAlipayForceQRCode: paymentCfg.AlipayForceQRCode,
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
@@ -376,6 +397,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ // DingTalk Connect OAuth 登录
+ DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
+ DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
+ DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"`
+ DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
+ DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
+ DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
+ DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
+ DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
+ DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
+ DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
+ DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
+ DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
+ DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
+ DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
+ DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
+ DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
+
// WeChat Connect OAuth 登录
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
@@ -446,45 +485,50 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
- AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
- AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
- AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
- DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
- AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
- AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
- AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
- AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
- AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
- AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
- AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
- AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
- AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
- AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
- AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
- AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
- AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
- AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
- AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
- AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
- AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
- AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
- AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
- AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
- AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
- AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
- AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
- AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
- AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
- AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
- AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
- AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
- AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
- AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
- ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
+ AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
+ AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
+ AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
+ AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
+ DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
+ DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
+ AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
+ AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
+ AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
+ AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
+ AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
+ AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
+ AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
+ AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
+ AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
+ AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
+ AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
+ AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
+ AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
+ AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
+ AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
+ AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
+ AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
+ AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
+ AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
+ AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
+ AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
+ AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
+ AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
+ AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
+ AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
+ AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
+ AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
+ AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
+ AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
+ AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"`
+ AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"`
+ AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"`
+ AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"`
+ AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"`
+ ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -560,6 +604,9 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
+ // Force Alipay mobile clients to use QR code payment instead of mobile redirect
+ PaymentAlipayForceQRCode *bool `json:"payment_alipay_force_qrcode"`
+
// Channel Monitor feature switch
ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
@@ -661,6 +708,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
+ req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@@ -777,6 +825,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
+ // DingTalk Connect 参数验证
+ // 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none,
+ // 避免任何直连 admin API 的客户端把死值写回 DB(前端 UI 已无此选项)。
+ req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy)
+
+ if req.DingTalkConnectEnabled {
+ req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID)
+ req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret)
+ req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL)
+ req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy)
+ req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID)
+
+ if req.DingTalkConnectClientID == "" {
+ response.BadRequest(c, "DingTalk Client ID is required when enabled")
+ return
+ }
+ if req.DingTalkConnectRedirectURL == "" {
+ response.BadRequest(c, "DingTalk Redirect URL is required when enabled")
+ return
+ }
+ if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil {
+ response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL")
+ return
+ }
+
+ // 如果未提供 client_secret,则保留现有值(如有)。
+ if req.DingTalkConnectClientSecret == "" {
+ if previousSettings.DingTalkConnectClientSecret == "" {
+ response.BadRequest(c, "DingTalk Client Secret is required when enabled")
+ return
+ }
+ req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret
+ }
+
+ // Corp 策略校验(V1/V4 fail-closed)
+ dingTalkCfg := config.DingTalkConnectConfig{
+ Enabled: true,
+ DingTalkAppKind: "internal_app", // 硬编码:settings 层仅支持 internal_app
+ AppType: "internal", // 对于 internal_only 策略的默认值
+ CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
+ InternalCorpID: req.DingTalkConnectInternalCorpID,
+ }
+ // 若未填 corp_restriction_policy,保留已有配置
+ if dingTalkCfg.CorpRestrictionPolicy == "" {
+ dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy
+ }
+ // 对于 internal_only 策略,app_type 必须为 internal(V1 校验)
+ if dingTalkCfg.CorpRestrictionPolicy == "internal_only" {
+ dingTalkCfg.AppType = "internal"
+ } else {
+ dingTalkCfg.AppType = "public"
+ }
+ if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil {
+ response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil)
+ return
+ }
+
+ // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false,
+ // 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。
+ if dingTalkCfg.CorpRestrictionPolicy != "internal_only" {
+ req.DingTalkConnectBypassRegistration = false
+ // 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。
+ req.DingTalkConnectSyncCorpEmail = false
+ req.DingTalkConnectSyncDisplayName = false
+ req.DingTalkConnectSyncDept = false
+ }
+ // 身份同步目标 attr key:trimSpace + 空值 fallback 到默认值
+ req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey)
+ if req.DingTalkConnectSyncCorpEmailAttrKey == "" {
+ req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
+ }
+ req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey)
+ if req.DingTalkConnectSyncDisplayNameAttrKey == "" {
+ req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
+ }
+ req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey)
+ if req.DingTalkConnectSyncDeptAttrKey == "" {
+ req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
+ }
+ // 身份同步目标 attr 显示名称:trim + 空值 fallback 到默认中文名
+ req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName)
+ if req.DingTalkConnectSyncCorpEmailAttrName == "" {
+ req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
+ }
+ req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName)
+ if req.DingTalkConnectSyncDisplayNameAttrName == "" {
+ req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
+ }
+ req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName)
+ if req.DingTalkConnectSyncDeptAttrName == "" {
+ req.DingTalkConnectSyncDeptAttrName = "钉钉部门"
+ }
+ }
+
if req.WeChatConnectEnabled {
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
@@ -1272,113 +1414,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
- RegistrationEnabled: req.RegistrationEnabled,
- EmailVerifyEnabled: req.EmailVerifyEnabled,
- RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
- PromoCodeEnabled: req.PromoCodeEnabled,
- PasswordResetEnabled: req.PasswordResetEnabled,
- FrontendURL: req.FrontendURL,
- InvitationCodeEnabled: req.InvitationCodeEnabled,
- TotpEnabled: req.TotpEnabled,
- LoginAgreementEnabled: req.LoginAgreementEnabled,
- LoginAgreementMode: loginAgreementMode,
- LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
- LoginAgreementDocuments: loginAgreementDocuments,
- SMTPHost: req.SMTPHost,
- SMTPPort: req.SMTPPort,
- SMTPUsername: req.SMTPUsername,
- SMTPPassword: req.SMTPPassword,
- SMTPFrom: req.SMTPFrom,
- SMTPFromName: req.SMTPFromName,
- SMTPUseTLS: req.SMTPUseTLS,
- TurnstileEnabled: req.TurnstileEnabled,
- TurnstileSiteKey: req.TurnstileSiteKey,
- TurnstileSecretKey: req.TurnstileSecretKey,
- LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
- LinuxDoConnectClientID: req.LinuxDoConnectClientID,
- LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
- LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
- WeChatConnectEnabled: req.WeChatConnectEnabled,
- WeChatConnectAppID: req.WeChatConnectAppID,
- WeChatConnectAppSecret: req.WeChatConnectAppSecret,
- WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
- WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
- WeChatConnectMPAppID: req.WeChatConnectMPAppID,
- WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
- WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
- WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
- WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
- WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
- WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
- WeChatConnectMode: req.WeChatConnectMode,
- WeChatConnectScopes: req.WeChatConnectScopes,
- WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
- WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
- OIDCConnectEnabled: req.OIDCConnectEnabled,
- OIDCConnectProviderName: req.OIDCConnectProviderName,
- OIDCConnectClientID: req.OIDCConnectClientID,
- OIDCConnectClientSecret: req.OIDCConnectClientSecret,
- OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
- OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
- OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
- OIDCConnectTokenURL: req.OIDCConnectTokenURL,
- OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
- OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
- OIDCConnectScopes: req.OIDCConnectScopes,
- OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
- OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
- OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
- OIDCConnectUsePKCE: oidcUsePKCE,
- OIDCConnectValidateIDToken: oidcValidateIDToken,
- OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
- OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
- OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
- OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
- OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
- OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
- GitHubOAuthEnabled: req.GitHubOAuthEnabled,
- GitHubOAuthClientID: req.GitHubOAuthClientID,
- GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
- GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
- GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
- GoogleOAuthEnabled: req.GoogleOAuthEnabled,
- GoogleOAuthClientID: req.GoogleOAuthClientID,
- GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
- GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
- GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
- SiteName: req.SiteName,
- SiteLogo: req.SiteLogo,
- SiteSubtitle: req.SiteSubtitle,
- APIBaseURL: req.APIBaseURL,
- ContactInfo: req.ContactInfo,
- DocURL: req.DocURL,
- HomeContent: req.HomeContent,
- HideCcsImportButton: req.HideCcsImportButton,
- PurchaseSubscriptionEnabled: purchaseEnabled,
- PurchaseSubscriptionURL: purchaseURL,
- TableDefaultPageSize: req.TableDefaultPageSize,
- TablePageSizeOptions: req.TablePageSizeOptions,
- CustomMenuItems: customMenuJSON,
- CustomEndpoints: customEndpointsJSON,
- DefaultConcurrency: req.DefaultConcurrency,
- DefaultBalance: req.DefaultBalance,
- AffiliateRebateRate: affiliateRebateRate,
- AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
- AffiliateRebateDurationDays: affiliateRebateDurationDays,
- AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
- DefaultUserRPMLimit: req.DefaultUserRPMLimit,
- DefaultSubscriptions: defaultSubscriptions,
- EnableModelFallback: req.EnableModelFallback,
- FallbackModelAnthropic: req.FallbackModelAnthropic,
- FallbackModelOpenAI: req.FallbackModelOpenAI,
- FallbackModelGemini: req.FallbackModelGemini,
- FallbackModelAntigravity: req.FallbackModelAntigravity,
- EnableIdentityPatch: req.EnableIdentityPatch,
- IdentityPatchPrompt: req.IdentityPatchPrompt,
- MinClaudeCodeVersion: req.MinClaudeCodeVersion,
- MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
- AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
- BackendModeEnabled: req.BackendModeEnabled,
+ RegistrationEnabled: req.RegistrationEnabled,
+ EmailVerifyEnabled: req.EmailVerifyEnabled,
+ RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
+ PromoCodeEnabled: req.PromoCodeEnabled,
+ PasswordResetEnabled: req.PasswordResetEnabled,
+ FrontendURL: req.FrontendURL,
+ InvitationCodeEnabled: req.InvitationCodeEnabled,
+ TotpEnabled: req.TotpEnabled,
+ LoginAgreementEnabled: req.LoginAgreementEnabled,
+ LoginAgreementMode: loginAgreementMode,
+ LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
+ LoginAgreementDocuments: loginAgreementDocuments,
+ SMTPHost: req.SMTPHost,
+ SMTPPort: req.SMTPPort,
+ SMTPUsername: req.SMTPUsername,
+ SMTPPassword: req.SMTPPassword,
+ SMTPFrom: req.SMTPFrom,
+ SMTPFromName: req.SMTPFromName,
+ SMTPUseTLS: req.SMTPUseTLS,
+ TurnstileEnabled: req.TurnstileEnabled,
+ TurnstileSiteKey: req.TurnstileSiteKey,
+ TurnstileSecretKey: req.TurnstileSecretKey,
+ LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
+ LinuxDoConnectClientID: req.LinuxDoConnectClientID,
+ LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
+ LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
+ DingTalkConnectEnabled: req.DingTalkConnectEnabled,
+ DingTalkConnectClientID: req.DingTalkConnectClientID,
+ DingTalkConnectClientSecret: req.DingTalkConnectClientSecret,
+ DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL,
+ DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
+ DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID,
+ DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration,
+ DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail,
+ DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName,
+ DingTalkConnectSyncDept: req.DingTalkConnectSyncDept,
+ DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey,
+ DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey,
+ DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey,
+ DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName,
+ DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName,
+ DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName,
+ WeChatConnectEnabled: req.WeChatConnectEnabled,
+ WeChatConnectAppID: req.WeChatConnectAppID,
+ WeChatConnectAppSecret: req.WeChatConnectAppSecret,
+ WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
+ WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
+ WeChatConnectMPAppID: req.WeChatConnectMPAppID,
+ WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
+ WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
+ WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
+ WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
+ WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
+ WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
+ WeChatConnectMode: req.WeChatConnectMode,
+ WeChatConnectScopes: req.WeChatConnectScopes,
+ WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
+ WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
+ OIDCConnectEnabled: req.OIDCConnectEnabled,
+ OIDCConnectProviderName: req.OIDCConnectProviderName,
+ OIDCConnectClientID: req.OIDCConnectClientID,
+ OIDCConnectClientSecret: req.OIDCConnectClientSecret,
+ OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
+ OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
+ OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
+ OIDCConnectTokenURL: req.OIDCConnectTokenURL,
+ OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
+ OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
+ OIDCConnectScopes: req.OIDCConnectScopes,
+ OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
+ OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
+ OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
+ OIDCConnectUsePKCE: oidcUsePKCE,
+ OIDCConnectValidateIDToken: oidcValidateIDToken,
+ OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
+ OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
+ OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
+ OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
+ OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
+ OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
+ GitHubOAuthEnabled: req.GitHubOAuthEnabled,
+ GitHubOAuthClientID: req.GitHubOAuthClientID,
+ GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
+ GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
+ GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
+ GoogleOAuthEnabled: req.GoogleOAuthEnabled,
+ GoogleOAuthClientID: req.GoogleOAuthClientID,
+ GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
+ GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
+ GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
+ SiteName: req.SiteName,
+ SiteLogo: req.SiteLogo,
+ SiteSubtitle: req.SiteSubtitle,
+ APIBaseURL: req.APIBaseURL,
+ ContactInfo: req.ContactInfo,
+ DocURL: req.DocURL,
+ HomeContent: req.HomeContent,
+ HideCcsImportButton: req.HideCcsImportButton,
+ PurchaseSubscriptionEnabled: purchaseEnabled,
+ PurchaseSubscriptionURL: purchaseURL,
+ TableDefaultPageSize: req.TableDefaultPageSize,
+ TablePageSizeOptions: req.TablePageSizeOptions,
+ CustomMenuItems: customMenuJSON,
+ CustomEndpoints: customEndpointsJSON,
+ DefaultConcurrency: req.DefaultConcurrency,
+ DefaultBalance: req.DefaultBalance,
+ AffiliateRebateRate: affiliateRebateRate,
+ AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
+ AffiliateRebateDurationDays: affiliateRebateDurationDays,
+ AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
+ DefaultUserRPMLimit: req.DefaultUserRPMLimit,
+ DefaultSubscriptions: defaultSubscriptions,
+ EnableModelFallback: req.EnableModelFallback,
+ FallbackModelAnthropic: req.FallbackModelAnthropic,
+ FallbackModelOpenAI: req.FallbackModelOpenAI,
+ FallbackModelGemini: req.FallbackModelGemini,
+ FallbackModelAntigravity: req.FallbackModelAntigravity,
+ EnableIdentityPatch: req.EnableIdentityPatch,
+ IdentityPatchPrompt: req.IdentityPatchPrompt,
+ MinClaudeCodeVersion: req.MinClaudeCodeVersion,
+ MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
+ AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
+ BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -1574,6 +1732,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
},
+ DingTalk: service.ProviderDefaultGrantSettings{
+ Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
+ Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency),
+ Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
+ GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
+ GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
+ },
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
@@ -1613,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
+ AlipayForceQRCode: req.PaymentAlipayForceQRCode,
}
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
response.ErrorFrom(c, err)
@@ -1632,6 +1798,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
+ h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings)
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
@@ -1682,6 +1849,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
+ DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled,
+ DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID,
+ DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured,
+ DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL,
+ DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy,
+ DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID,
+ DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration,
+ DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail,
+ DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName,
+ DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept,
+ DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey,
+ DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey,
+ DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey,
+ DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName,
+ DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName,
+ DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
@@ -1803,6 +1986,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
+ PaymentAlipayForceQRCode: updatedPaymentCfg.AlipayForceQRCode,
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
@@ -1822,6 +2006,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
+// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes.
+func mapDingTalkValidateError(err error) string {
+ switch {
+ case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch):
+ return "dingtalk_apptype_mismatch"
+ case errors.Is(err, config.ErrDingTalkV4InvalidAppKind):
+ return "dingtalk_app_kind_invalid"
+ default:
+ return "dingtalk_corp_config_invalid"
+ }
+}
+
func hasPaymentFields(req UpdateSettingsRequest) bool {
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
@@ -1832,7 +2028,8 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
req.PaymentCancelRateLimitMax != nil || req.PaymentCancelRateLimitWindow != nil ||
- req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
+ req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil ||
+ req.PaymentAlipayForceQRCode != nil
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
@@ -1935,6 +2132,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
+ if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled {
+ changed = append(changed, "dingtalk_connect_enabled")
+ }
+ if before.DingTalkConnectClientID != after.DingTalkConnectClientID {
+ changed = append(changed, "dingtalk_connect_client_id")
+ }
+ if req.DingTalkConnectClientSecret != "" {
+ changed = append(changed, "dingtalk_connect_client_secret")
+ }
+ if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL {
+ changed = append(changed, "dingtalk_connect_redirect_url")
+ }
+ if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy {
+ changed = append(changed, "dingtalk_connect_corp_restriction_policy")
+ }
+ if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID {
+ changed = append(changed, "dingtalk_connect_internal_corp_id")
+ }
+ if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration {
+ changed = append(changed, "dingtalk_connect_bypass_registration")
+ }
+ if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail {
+ changed = append(changed, "dingtalk_connect_sync_corp_email")
+ }
+ if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName {
+ changed = append(changed, "dingtalk_connect_sync_display_name")
+ }
+ if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept {
+ changed = append(changed, "dingtalk_connect_sync_dept")
+ }
+ if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey {
+ changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key")
+ }
+ if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey {
+ changed = append(changed, "dingtalk_connect_sync_display_name_attr_key")
+ }
+ if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey {
+ changed = append(changed, "dingtalk_connect_sync_dept_attr_key")
+ }
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
changed = append(changed, "wechat_connect_enabled")
}
@@ -2246,6 +2482,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
{name: "wechat", before: before.WeChat, after: after.WeChat},
{name: "github", before: before.GitHub, after: after.GitHub},
{name: "google", before: before.Google, after: after.Google},
+ {name: "dingtalk", before: before.DingTalk, after: after.DingTalk},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
@@ -2350,6 +2587,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
+ data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance
+ data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency
+ data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions
+ data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup
+ data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
@@ -3044,3 +3286,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
}
response.Success(c, result)
}
+
+// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name)
+// 兜底 upsert 对应 user attribute definition:不存在则创建;存在但 name 不同则更新 name
+// (type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。
+// 失败仅记录日志,不阻塞 settings 保存。
+func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) {
+ if h.userAttributeService == nil || settings == nil {
+ return
+ }
+ if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
+ return
+ }
+ if settings.DingTalkConnectSyncDisplayName {
+ h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText)
+ }
+ if settings.DingTalkConnectSyncCorpEmail {
+ h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail)
+ }
+ if settings.DingTalkConnectSyncDept {
+ h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText)
+ }
+}
+
+func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) {
+ key = strings.TrimSpace(key)
+ if key == "" {
+ return
+ }
+ existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
+ if err == nil && existing != nil {
+ if strings.TrimSpace(name) != "" && existing.Name != name {
+ if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{
+ Name: &name,
+ }); err != nil {
+ slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error())
+ return
+ }
+ slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name)
+ }
+ return
+ }
+ if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{
+ Key: key,
+ Name: name,
+ Description: description,
+ Type: attrType,
+ Enabled: true,
+ }); err != nil {
+ slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error())
+ return
+ }
+ slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
+}
diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
index 085fd2ca..f953f767 100644
--- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
+++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go
@@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
@@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
@@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul
ClockSkewSeconds: 120,
},
})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu
err: errors.New("write auth source defaults failed"),
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
- handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
diff --git a/backend/internal/handler/admin/setting_handler_dingtalk_test.go b/backend/internal/handler/admin/setting_handler_dingtalk_test.go
new file mode 100644
index 00000000..a3d944cc
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_dingtalk_test.go
@@ -0,0 +1,319 @@
+package admin
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub(已在 setting_handler_auth_source_defaults_test.go 定义)
+
+func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) {
+ repo := &settingHandlerRepoStub{values: map[string]string{}}
+ svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
+ handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
+ return handler, repo
+}
+
+// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。
+func baseValidDingTalkBody() map[string]any {
+ return map[string]any{
+ "dingtalk_connect_enabled": true,
+ "dingtalk_connect_client_id": "test-client-id",
+ "dingtalk_connect_client_secret": "test-client-secret",
+ "dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback",
+ "dingtalk_connect_corp_restriction_policy": "none",
+ }
+}
+
+// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
+// internal_only + internal_corp_id="" 应通过校验(→ 200),不再是 400。
+func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200
+func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "none"
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, data["dingtalk_connect_enabled"])
+}
+
+// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200
+func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ body["dingtalk_connect_internal_corp_id"] = "ding-corp-123"
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。
+// 必须用 policy=internal_only:bypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。
+func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ body["dingtalk_connect_bypass_registration"] = true
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, data["dingtalk_connect_bypass_registration"])
+}
+
+// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。
+// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为
+// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。
+func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := map[string]any{
+ "dingtalk_connect_enabled": false,
+ "dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝
+ "dingtalk_connect_corp_restriction_policy": "internal_only",
+ }
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。
+func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ body["dingtalk_connect_sync_corp_email"] = true
+ body["dingtalk_connect_sync_display_name"] = true
+ body["dingtalk_connect_sync_dept"] = true
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only")
+ require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only")
+ require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only")
+}
+
+// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。
+func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, _ := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "none"
+ body["dingtalk_connect_sync_corp_email"] = true
+ body["dingtalk_connect_sync_display_name"] = true
+ body["dingtalk_connect_sync_dept"] = true
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none")
+ require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none")
+ require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none")
+}
+
+// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容:
+// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中)
+// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。
+func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ handler, repo := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "whitelist"
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy],
+ "stale whitelist 应在写入路径被 coerce 为 none")
+}
+
+// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。
+func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ t.Run("custom_attr_keys_saved", func(t *testing.T) {
+ handler, repo := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ body["dingtalk_connect_sync_corp_email"] = true
+ body["dingtalk_connect_sync_display_name"] = true
+ body["dingtalk_connect_sync_dept"] = true
+ body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr"
+ body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr"
+ body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr"
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ // 验证写入 DB 的 key
+ require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
+ require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
+ require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
+
+ // 验证响应中的 attr key
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ data, ok := resp.Data.(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"])
+ require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"])
+ require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"])
+ })
+
+ t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) {
+ handler, repo := newDingTalkSettingsHandler()
+
+ body := baseValidDingTalkBody()
+ body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
+ // 不传 attr key → 写入层 fallback 到默认值
+ body["dingtalk_connect_sync_corp_email_attr_key"] = ""
+ body["dingtalk_connect_sync_display_name_attr_key"] = ""
+ body["dingtalk_connect_sync_dept_attr_key"] = ""
+
+ rawBody, err := json.Marshal(body)
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ handler.UpdateSettings(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ // 空值应 fallback 到默认值并持久化
+ require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
+ require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
+ require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
+ })
+}
diff --git a/backend/internal/handler/auth_dingtalk_client.go b/backend/internal/handler/auth_dingtalk_client.go
new file mode 100644
index 00000000..2db07d05
--- /dev/null
+++ b/backend/internal/handler/auth_dingtalk_client.go
@@ -0,0 +1,398 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集
+type dingTalkClientConfig struct {
+ ClientID string
+ ClientSecret string
+ TokenURL string
+ UserInfoURL string
+}
+
+type DingTalkClient struct {
+ cfg dingTalkClientConfig
+ appToken string
+ appTokenExp time.Time // 钉钉 7200s,留 200s 余量 → 7000s
+ mu sync.Mutex
+ httpClient *http.Client
+ // TODO(multi-instance): Redis 集中缓存 appToken
+}
+
+type DingTalkUserTokenResp struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ExpireIn int64 `json:"expireIn"`
+ CorpID string `json:"corpId"`
+}
+
+func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) {
+ body := map[string]string{
+ "clientId": c.cfg.ClientID,
+ "clientSecret": c.cfg.ClientSecret,
+ "code": code,
+ "grantType": "authorization_code",
+ }
+ payload, _ := json.Marshal(body)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var out DingTalkUserTokenResp
+ if err := json.Unmarshal(raw, &out); err != nil {
+ return nil, err
+ }
+ if strings.TrimSpace(out.AccessToken) == "" {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ return &out, nil
+}
+
+type DingTalkAPIError struct {
+ Code string
+ Message string
+ HTTP int
+}
+
+func (e *DingTalkAPIError) Error() string {
+ return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP)
+}
+
+func parseDingTalkErr(raw []byte, status int) error {
+ var v struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ }
+ _ = json.Unmarshal(raw, &v)
+ code := v.Code
+ if code == "" && v.ErrCode != 0 {
+ code = fmt.Sprintf("%d", v.ErrCode)
+ }
+ msg := v.Message
+ if msg == "" {
+ msg = v.ErrMsg
+ }
+ return &DingTalkAPIError{Code: code, Message: msg, HTTP: status}
+}
+
+// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。
+// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。
+func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil)
+ if err != nil {
+ return "", "", err
+ }
+ req.Header.Set("x-acs-dingtalk-access-token", userToken)
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return "", "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return "", "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var v struct {
+ UnionID string `json:"unionId"`
+ Nick string `json:"nick"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return "", "", err
+ }
+ if strings.TrimSpace(v.UnionID) == "" {
+ return "", "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ return v.UnionID, v.Nick, nil
+}
+
+type DingTalkStaffInfo struct {
+ UserID string
+ Name string // 企业内真实姓名(钉钉企业管理后台配置)
+ Nickname string // 钉钉个人昵称(用户自己设置)
+ Email string
+ DeptIDs []int64
+ // CorpID 不来自 staff 接口,来自 userToken;不在此 struct
+}
+
+// dingTalkOAPIBase 推导钉钉旧版 OAPI base URL(host: api.dingtalk.com → oapi.dingtalk.com)。
+// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。
+func (c *DingTalkClient) dingTalkOAPIBase() string {
+ u, err := url.Parse(c.cfg.UserInfoURL)
+ if err != nil || u.Scheme == "" || u.Host == "" {
+ return "https://oapi.dingtalk.com"
+ }
+ host := u.Host
+ if strings.HasPrefix(host, "api.") {
+ host = "oapi." + strings.TrimPrefix(host, "api.")
+ }
+ return u.Scheme + "://" + host
+}
+
+func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.appToken != "" && time.Now().Before(c.appTokenExp) {
+ return c.appToken, nil
+ }
+ body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret}
+ payload, _ := json.Marshal(body)
+ // 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken
+ // 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明)
+ appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1)
+ if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") {
+ appTokenURL = c.cfg.TokenURL // fallback for test stub
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload))
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var v struct {
+ AccessToken string `json:"accessToken"`
+ ExpireIn int64 `json:"expireIn"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return "", err
+ }
+ if v.AccessToken == "" {
+ return "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ c.appToken = v.AccessToken
+ ttl := v.ExpireIn
+ if ttl > 200 {
+ ttl -= 200
+ }
+ c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second)
+ return c.appToken, nil
+}
+
+func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) {
+ appToken, err := c.GetAppToken(ctx)
+ if err != nil {
+ return "", err
+ }
+ body := map[string]string{"unionid": unionID}
+ payload, _ := json.Marshal(body)
+ // 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX
+ // access_token 通过 query string 传递(不是 header)
+ var targetURL string
+ if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
+ targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken)
+ } else {
+ targetURL = c.cfg.UserInfoURL // fallback for test stub
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
+ if err != nil {
+ return "", err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return "", err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var v struct {
+ Result struct {
+ UserID string `json:"userid"`
+ } `json:"result"`
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return "", err
+ }
+ if v.ErrCode != 0 {
+ return "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ if strings.TrimSpace(v.Result.UserID) == "" {
+ return "", parseDingTalkErr(raw, resp.StatusCode)
+ }
+ return v.Result.UserID, nil
+}
+
+// DingTalkDeptInfo 部门信息(topapi/v2/department/get 返回子集)
+type DingTalkDeptInfo struct {
+ DeptID int64
+ Name string
+ ParentID int64
+}
+
+// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。
+// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX
+func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) {
+ appToken, err := c.GetAppToken(ctx)
+ if err != nil {
+ return nil, err
+ }
+ body := map[string]any{"dept_id": deptID, "language": "zh_CN"}
+ payload, _ := json.Marshal(body)
+ var targetURL string
+ if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
+ targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken)
+ } else {
+ targetURL = c.cfg.UserInfoURL // test stub fallback
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var v struct {
+ Result struct {
+ DeptID int64 `json:"dept_id"`
+ Name string `json:"name"`
+ ParentID int64 `json:"parent_id"`
+ } `json:"result"`
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return nil, err
+ }
+ if v.ErrCode != 0 {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ return &DingTalkDeptInfo{
+ DeptID: v.Result.DeptID,
+ Name: v.Result.Name,
+ ParentID: v.Result.ParentID,
+ }, nil
+}
+
+func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) {
+ appToken, err := c.GetAppToken(ctx)
+ if err != nil {
+ return nil, err
+ }
+ body := map[string]string{"userid": userID}
+ payload, _ := json.Marshal(body)
+ // 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX
+ var targetURL string
+ if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
+ targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken)
+ } else {
+ targetURL = c.cfg.UserInfoURL // fallback for test stub
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = resp.Body.Close() }()
+ raw, _ := io.ReadAll(resp.Body)
+ if resp.StatusCode != http.StatusOK {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ var v struct {
+ Result struct {
+ UserID string `json:"userid"`
+ Name string `json:"name"`
+ Nickname string `json:"nickname"`
+ Email string `json:"email"`
+ OrgEmail string `json:"org_email"`
+ Extension string `json:"extension"`
+ DeptID []int64 `json:"dept_id_list"`
+ } `json:"result"`
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ }
+ if err := json.Unmarshal(raw, &v); err != nil {
+ return nil, err
+ }
+ if v.ErrCode != 0 {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ if strings.TrimSpace(v.Result.UserID) == "" {
+ return nil, parseDingTalkErr(raw, resp.StatusCode)
+ }
+ // 邮箱三级 fallback:org_email > email > extension["企业邮箱"](钉钉自定义扩展字段,JSON string)
+ email := strings.TrimSpace(v.Result.OrgEmail)
+ emailSource := "org_email"
+ if email == "" {
+ email = strings.TrimSpace(v.Result.Email)
+ emailSource = "email"
+ }
+ extensionParsed := false
+ if email == "" && strings.TrimSpace(v.Result.Extension) != "" {
+ var ext map[string]string
+ if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil {
+ extensionParsed = true
+ if v, ok := ext["企业邮箱"]; ok {
+ email = strings.TrimSpace(v)
+ emailSource = "extension.企业邮箱"
+ }
+ }
+ }
+ if email == "" {
+ emailSource = "none"
+ }
+ slog.Info("dingtalk staff fetched",
+ "userid", v.Result.UserID,
+ "name_present", v.Result.Name != "",
+ "nickname_present", v.Result.Nickname != "",
+ "name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname,
+ "email_present", v.Result.Email != "",
+ "org_email_present", v.Result.OrgEmail != "",
+ "extension_present", v.Result.Extension != "",
+ "extension_parsed", extensionParsed,
+ "email_source", emailSource,
+ "dept_count", len(v.Result.DeptID),
+ )
+ return &DingTalkStaffInfo{
+ UserID: v.Result.UserID,
+ Name: v.Result.Name,
+ Nickname: v.Result.Nickname,
+ Email: email,
+ DeptIDs: v.Result.DeptID,
+ }, nil
+}
diff --git a/backend/internal/handler/auth_dingtalk_client_test.go b/backend/internal/handler/auth_dingtalk_client_test.go
new file mode 100644
index 00000000..aa2e2fdd
--- /dev/null
+++ b/backend/internal/handler/auth_dingtalk_client_test.go
@@ -0,0 +1,143 @@
+package handler
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "POST", r.Method)
+ require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path)
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{
+ ClientID: "k", ClientSecret: "s",
+ TokenURL: server.URL + "/v1.0/oauth2/userAccessToken",
+ },
+ httpClient: server.Client(),
+ }
+ resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE")
+ require.NoError(t, err)
+ require.Equal(t, "USER_TOKEN_X", resp.AccessToken)
+ require.Equal(t, "dingABC", resp.CorpID)
+}
+
+func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token"))
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"},
+ httpClient: server.Client(),
+ }
+ unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X")
+ require.NoError(t, err)
+ require.Equal(t, "UID_AAA", unionID)
+ require.Equal(t, "张三", nick)
+}
+
+func TestDingTalkClient_GetAppToken_Cached(t *testing.T) {
+ callCount := 0
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ callCount++
+ _, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"},
+ httpClient: server.Client(),
+ }
+ t1, err := cli.GetAppToken(context.Background())
+ require.NoError(t, err)
+ t2, err := cli.GetAppToken(context.Background())
+ require.NoError(t, err)
+ require.Equal(t, t1, t2)
+ require.Equal(t, 1, callCount, "second call should hit cache")
+}
+
+func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) {
+ appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
+ }))
+ defer appTokenServer.Close()
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"},
+ httpClient: server.Client(),
+ }
+ cli.appToken = "APP_TKN"
+ cli.appTokenExp = time.Now().Add(time.Hour)
+ cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId"
+
+ _, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA")
+ require.Error(t, err)
+ apiErr, ok := err.(*DingTalkAPIError)
+ require.True(t, ok)
+ require.Equal(t, "60011", apiErr.Code)
+}
+
+// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。
+func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{
+ UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me,走 test stub 路径
+ },
+ httpClient: server.Client(),
+ }
+ cli.appToken = "APP_TKN"
+ cli.appTokenExp = time.Now().Add(time.Hour)
+
+ info, err := cli.GetDeptInfo(context.Background(), 42)
+ require.NoError(t, err)
+ require.Equal(t, int64(42), info.DeptID)
+ require.Equal(t, "AI数据", info.Name)
+ require.Equal(t, int64(1), info.ParentID)
+}
+
+// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003(部门不存在)时返回错误。
+func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`))
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
+ httpClient: server.Client(),
+ }
+ cli.appToken = "APP_TKN"
+ cli.appTokenExp = time.Now().Add(time.Hour)
+
+ _, err := cli.GetDeptInfo(context.Background(), 999)
+ require.Error(t, err)
+ apiErr, ok := err.(*DingTalkAPIError)
+ require.True(t, ok)
+ require.Equal(t, "60003", apiErr.Code)
+}
diff --git a/backend/internal/handler/auth_dingtalk_oauth.go b/backend/internal/handler/auth_dingtalk_oauth.go
new file mode 100644
index 00000000..a5b27dc6
--- /dev/null
+++ b/backend/internal/handler/auth_dingtalk_oauth.go
@@ -0,0 +1,1066 @@
+package handler
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbuser "github.com/Wei-Shaw/sub2api/ent/user"
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// dingTalkUpstreamRedirect 在 4 步链上游调用失败时记录详细错误日志并跳错误页。
+// 把钉钉 errcode/errmsg 写进 backend log + URL fragment,避免被泛 "internal error" 吞掉。
+func dingTalkUpstreamRedirect(c *gin.Context, frontendCallback, step string, err error) {
+ var apiErr *DingTalkAPIError
+ dtCode := ""
+ dtMsg := ""
+ dtHTTP := 0
+ if errors.As(err, &apiErr) {
+ dtCode = apiErr.Code
+ dtMsg = apiErr.Message
+ dtHTTP = apiErr.HTTP
+ }
+ slog.Error("dingtalk upstream call failed",
+ "step", step,
+ "dingtalk_code", dtCode,
+ "dingtalk_msg", dtMsg,
+ "http_status", dtHTTP,
+ "error", err.Error(),
+ )
+ msg := dtMsg
+ if strings.TrimSpace(msg) == "" {
+ msg = infraerrors.Message(err)
+ }
+ if strings.TrimSpace(dtCode) != "" {
+ msg = "dingtalk[" + dtCode + "] " + msg
+ }
+ redirectOAuthError(c, frontendCallback, mapDingTalkErrorCode(err), msg, "")
+}
+
+// ─── 常量 ──────────────────────────────────────────────────────────────────
+
+const (
+ dingTalkOAuthCookiePath = "/api/v1/auth/oauth/dingtalk"
+ dingTalkOAuthStateCookieName = "dingtalk_oauth_state"
+ dingTalkOAuthRedirectCookie = "dingtalk_oauth_redirect"
+ dingTalkOAuthIntentCookieName = "dingtalk_oauth_intent"
+ dingTalkOAuthBindUserCookieName = "dingtalk_oauth_bind_user"
+ dingTalkOAuthCookieMaxAgeSec = 600 // 10 分钟
+ dingTalkOAuthDefaultRedirectTo = "/dashboard"
+ dingTalkOAuthDefaultFrontendCB = "/auth/dingtalk/callback"
+
+ dingTalkLevelThreeEnabled = true
+)
+
+// ─── Config helper ─────────────────────────────────────────────────────────
+
+// getDingTalkOAuthConfig 返回 DingTalk OAuth 最终生效配置。
+// 优先从 settingSvc(settings 表)读取,回退到 h.cfg.DingTalk。
+func (h *AuthHandler) getDingTalkOAuthConfig(ctx context.Context) (config.DingTalkConnectConfig, error) {
+ if h != nil && h.settingSvc != nil {
+ return h.settingSvc.GetDingTalkConnectOAuthConfig(ctx)
+ }
+ if h == nil || h.cfg == nil {
+ return config.DingTalkConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
+ }
+ if !h.cfg.DingTalk.Enabled {
+ return config.DingTalkConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "dingtalk oauth login is disabled")
+ }
+ return h.cfg.DingTalk, nil
+}
+
+// ─── Cookie helpers(使用 dingtalk path)─────────────────────────────────
+
+func setDingTalkCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: value,
+ Path: dingTalkOAuthCookiePath,
+ MaxAge: maxAgeSec,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+func clearDingTalkCookie(c *gin.Context, name string, secure bool) {
+ http.SetCookie(c.Writer, &http.Cookie{
+ Name: name,
+ Value: "",
+ Path: dingTalkOAuthCookiePath,
+ MaxAge: -1,
+ HttpOnly: true,
+ Secure: secure,
+ SameSite: http.SameSiteLaxMode,
+ })
+}
+
+// ─── DingTalkOAuthStart ────────────────────────────────────────────────────
+
+// DingTalkOAuthStart 启动 DingTalk Connect OAuth 登录流程。
+// GET /api/v1/auth/oauth/dingtalk/start?redirect=/dashboard&intent=login
+func (h *AuthHandler) DingTalkOAuthStart(c *gin.Context) {
+ cfg, err := h.getDingTalkOAuthConfig(c.Request.Context())
+ if err != nil {
+ frontendCB := dingTalkOAuthDefaultFrontendCB
+ redirectOAuthError(c, frontendCB, "dingtalk_not_enabled", "", "")
+ return
+ }
+
+ state, err := oauth.GenerateState()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err))
+ return
+ }
+
+ redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect"))
+ if redirectTo == "" {
+ redirectTo = dingTalkOAuthDefaultRedirectTo
+ }
+
+ browserSessionKey, err := generateOAuthPendingBrowserSession()
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err))
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ setDingTalkCookie(c, dingTalkOAuthStateCookieName, encodeCookieValue(state), dingTalkOAuthCookieMaxAgeSec, secureCookie)
+ setDingTalkCookie(c, dingTalkOAuthRedirectCookie, encodeCookieValue(redirectTo), dingTalkOAuthCookieMaxAgeSec, secureCookie)
+
+ intent := normalizeOAuthIntent(c.Query("intent"))
+ setDingTalkCookie(c, dingTalkOAuthIntentCookieName, encodeCookieValue(intent), dingTalkOAuthCookieMaxAgeSec, secureCookie)
+
+ setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+
+ if intent == oauthIntentBindCurrentUser {
+ bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ setDingTalkCookie(c, dingTalkOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), dingTalkOAuthCookieMaxAgeSec, secureCookie)
+ } else {
+ clearDingTalkCookie(c, dingTalkOAuthBindUserCookieName, secureCookie)
+ }
+
+ authURL, err := buildDingTalkAuthorizeURL(cfg, state)
+ if err != nil {
+ response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build dingtalk authorization url").WithCause(err))
+ return
+ }
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+// ─── buildDingTalkAuthorizeURL ─────────────────────────────────────────────
+
+// ─── findDingTalkCompatEmailUser ───────────────────────────────────────────
+
+// findDingTalkCompatEmailUser 通过真实邮箱查找可与 DingTalk 账号兼容绑定的现有用户。
+func (h *AuthHandler) findDingTalkCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) {
+ if !dingTalkLevelThreeEnabled {
+ return nil, nil
+ }
+
+ client := h.entClient()
+ if client == nil {
+ return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")
+ }
+
+ email = strings.TrimSpace(strings.ToLower(email))
+ if email == "" ||
+ strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ return nil, nil
+ }
+
+ userEntities, err := client.User.Query().
+ Where(userNormalizedEmailPredicate(email)).
+ Order(dbent.Asc(dbuser.FieldID)).
+ All(ctx)
+ if err != nil {
+ return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err)
+ }
+ switch len(userEntities) {
+ case 0:
+ return nil, nil
+ case 1:
+ return userEntities[0], nil
+ default:
+ return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users")
+ }
+}
+
+// ─── createDingTalkOAuthChoicePendingSession ───────────────────────────────
+
+// createDingTalkOAuthChoicePendingSession 创建 DingTalk OAuth 三方注册/绑定的 choice pending session。
+// signupBlocked=true 时关闭"创建新账户"出口;若同时没有 compat email 匹配的已有账户,
+// 直接把 step 切到 bind_login_required,避免前端展示一个没有实际可点选项的 choice 界面。
+func (h *AuthHandler) createDingTalkOAuthChoicePendingSession(
+ c *gin.Context,
+ identity service.PendingAuthIdentityKey,
+ suggestedEmail string,
+ resolvedEmail string,
+ redirectTo string,
+ browserSessionKey string,
+ upstreamClaims map[string]any,
+ compatEmail string,
+ compatEmailUser *dbent.User,
+ forceEmailOnSignup bool,
+ signupBlocked bool,
+) error {
+ suggestionEmail := strings.TrimSpace(suggestedEmail)
+ canonicalEmail := strings.TrimSpace(resolvedEmail)
+ if suggestionEmail == "" {
+ suggestionEmail = canonicalEmail
+ }
+
+ completionResponse := map[string]any{
+ "step": oauthPendingChoiceStep,
+ "adoption_required": true,
+ "redirect": strings.TrimSpace(redirectTo),
+ "email": suggestionEmail,
+ "resolved_email": canonicalEmail,
+ "existing_account_email": "",
+ "existing_account_bindable": false,
+ "create_account_allowed": !signupBlocked,
+ "force_email_on_signup": forceEmailOnSignup,
+ "choice_reason": "third_party_signup",
+ }
+ if strings.TrimSpace(compatEmail) != "" {
+ completionResponse["compat_email"] = strings.TrimSpace(compatEmail)
+ }
+ resolvedChoiceEmail := suggestionEmail
+ if compatEmailUser != nil {
+ completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email)
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "compat_email_match"
+ resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email)
+ }
+ if forceEmailOnSignup && compatEmailUser == nil {
+ completionResponse["choice_reason"] = "force_email_on_signup"
+ }
+ // 注册被拦:无论是否匹配到 compat email user,都跳过 choice,直接进 bind_login。
+ // "开放注册" 关闭 且 "钉钉企业模式豁免" 也关闭时,唯一合法出口是绑定已有账户,
+ // 不应该让用户看到"创建新账户"按钮;compat user 命中只是让 bind_login 的邮箱字段预填得更准。
+ if signupBlocked {
+ completionResponse["step"] = "bind_login_required"
+ completionResponse["existing_account_bindable"] = true
+ completionResponse["choice_reason"] = "signup_blocked_redirect_to_bind"
+ }
+
+ var targetUserID *int64
+ if compatEmailUser != nil && compatEmailUser.ID > 0 {
+ targetUserID = &compatEmailUser.ID
+ }
+
+ return h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin,
+ Identity: identity,
+ TargetUserID: targetUserID,
+ ResolvedEmail: resolvedChoiceEmail,
+ RedirectTo: redirectTo,
+ BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ })
+}
+
+// ─── DingTalkOAuthCallback ─────────────────────────────────────────────────
+
+// DingTalkOAuthCallback 处理钉钉授权回调。
+// GET /api/v1/auth/oauth/dingtalk/callback?code=...&state=...
+func (h *AuthHandler) DingTalkOAuthCallback(c *gin.Context) {
+ cfg, cfgErr := h.getDingTalkOAuthConfig(c.Request.Context())
+ if cfgErr != nil {
+ response.ErrorFrom(c, cfgErr)
+ return
+ }
+
+ frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL)
+ if frontendCallback == "" {
+ frontendCallback = dingTalkOAuthDefaultFrontendCB
+ }
+
+ if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" {
+ redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description"))
+ return
+ }
+
+ code := strings.TrimSpace(c.Query("code"))
+ state := strings.TrimSpace(c.Query("state"))
+ if code == "" || state == "" {
+ redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "")
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ defer func() {
+ clearDingTalkCookie(c, dingTalkOAuthStateCookieName, secureCookie)
+ clearDingTalkCookie(c, dingTalkOAuthRedirectCookie, secureCookie)
+ clearDingTalkCookie(c, dingTalkOAuthIntentCookieName, secureCookie)
+ }()
+
+ expectedState, err := readCookieDecoded(c, dingTalkOAuthStateCookieName)
+ if err != nil || state != expectedState {
+ redirectOAuthError(c, frontendCallback, "csrf", "state mismatch", "")
+ return
+ }
+ redirectTo, _ := readCookieDecoded(c, dingTalkOAuthRedirectCookie)
+ intent, _ := readCookieDecoded(c, dingTalkOAuthIntentCookieName)
+ intent = normalizeOAuthIntent(intent)
+ browserSessionKey, _ := readOAuthPendingBrowserCookie(c)
+ if strings.TrimSpace(browserSessionKey) == "" {
+ redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing browser session cookie", "")
+ return
+ }
+ forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context())
+
+ // ─── 4 步链(Step 1 + Step 2 必须;Step 3/4 按需 + 跨组织降级)───
+ client := h.dingTalkClient(cfg)
+ userToken, err := client.ExchangeCodeForUserToken(c.Request.Context(), code)
+ if err != nil {
+ dingTalkUpstreamRedirect(c, frontendCallback, "exchange_code", err)
+ return
+ }
+
+ // D: corp 校验提前到 Step 1 之后、Step 2 之前,减少不必要的上游调用
+ corpID := strings.TrimSpace(userToken.CorpID)
+ if !checkDingTalkCorpAllowed(cfg, corpID) {
+ // 不在 URL 中透传 corpID,避免内部企业标识泄露给前端
+ redirectOAuthError(c, frontendCallback, "corp_rejected", "", "")
+ return
+ }
+
+ // Step 2: 必须 — UnionID 是全局唯一,作为 subject + 合成邮箱种子;nick 是用户在 App 自设的昵称
+ unionID, oauthNick, err := client.GetUnionIdByUserToken(c.Request.Context(), userToken.AccessToken)
+ if err != nil {
+ dingTalkUpstreamRedirect(c, frontendCallback, "get_union_id", err)
+ return
+ }
+
+ identityKey := service.PendingAuthIdentityKey{ProviderType: "dingtalk", ProviderKey: "dingtalk", ProviderSubject: unionID}
+
+ // Step 3/4 调用策略由 policy 决定,与 require_email 解耦。
+ // policy=internal_only → 必须成功(hard fail),因为 AppType=internal 已保证用户在应用企业。
+ // policy=none / "" → 尝试,失败降级(公网场景跨组织用户属正常预期)。
+ // require_email 只影响 Step 3/4 结果后的邮箱处理路径,不影响是否调用。
+ var staff *DingTalkStaffInfo
+ switch cfg.CorpRestrictionPolicy {
+ case "internal_only":
+ // AppType=internal 已保证用户在应用企业,Step 3/4 必须成功。
+ // 失败 = 钉钉 OAPI 故障或应用配置错误,应 hard fail。
+ upstreamUserID, errStep3 := client.GetUserIdByUnionId(c.Request.Context(), unionID)
+ if errStep3 != nil {
+ dingTalkUpstreamRedirect(c, frontendCallback, "get_user_id", errStep3)
+ return
+ }
+ staffInfo, errStep4 := client.GetStaffInfoByUserId(c.Request.Context(), upstreamUserID)
+ if errStep4 != nil {
+ dingTalkUpstreamRedirect(c, frontendCallback, "get_staff_info", errStep4)
+ return
+ }
+ staff = staffInfo
+
+ default: // "none" or ""
+ // 公网登录,跨组织用户 Step 3/4 可能失败(设计预期),尝试调用,失败降级。
+ // 即使 require_email=false 也尝试拿 name(用于 upstreamClaims.username),失败就空着。
+ upstreamUserID, errStep3 := client.GetUserIdByUnionId(c.Request.Context(), unionID)
+ if errStep3 != nil {
+ slog.Debug("dingtalk step3 fallback (none/cross-org)",
+ "corp_id", corpID, "union_id", unionID, "err", errStep3.Error())
+ staff = &DingTalkStaffInfo{}
+ break
+ }
+ staffInfo, errStep4 := client.GetStaffInfoByUserId(c.Request.Context(), upstreamUserID)
+ if errStep4 != nil {
+ slog.Debug("dingtalk step4 fallback (none/cross-org)",
+ "corp_id", corpID, "union_id", unionID, "err", errStep4.Error())
+ staff = &DingTalkStaffInfo{}
+ break
+ }
+ staff = staffInfo
+ }
+
+ // nick 来自 OIDC /contact/users/me,优先作为钉钉昵称(user/get.nickname 多数为空)。
+ if staff != nil && strings.TrimSpace(oauthNick) != "" {
+ staff.Nickname = strings.TrimSpace(oauthNick)
+ }
+
+ upstreamClaims := buildDingTalkUpstreamClaims(staff, unionID, corpID)
+
+ // ─── S1 主动绑定分支(PR-3 才走到这里)───
+ if intent == oauthIntentBindCurrentUser {
+ targetUserID, err := h.readOAuthBindUserIDFromCookie(c, dingTalkOAuthBindUserCookieName)
+ if err != nil {
+ redirectOAuthError(c, frontendCallback, "invalid_state", "invalid bind user cookie", "")
+ return
+ }
+ // policy=none 跨组织用户绑定时 staff.Email="",用合成邮箱占位(用于 audit log,不用于注册)
+ bindResolvedEmail := staff.Email
+ if bindResolvedEmail == "" {
+ bindResolvedEmail = buildDingTalkSyntheticEmail(unionID)
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentBindCurrentUser, Identity: identityKey,
+ TargetUserID: &targetUserID, ResolvedEmail: bindResolvedEmail,
+ RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{"redirect": redirectTo},
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ clearDingTalkCookie(c, dingTalkOAuthBindUserCookieName, secureCookie)
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ // ─── Level 1:auth_identities hit ───
+ if existing, _ := h.findOAuthIdentityUser(c.Request.Context(), identityKey); existing != nil {
+ // 身份同步:已登录用户,直接同步(user_id 已知)。
+ // 异步执行避免上游钉钉接口(GetStaffInfoByUserId / 部门递归)阻塞登录跳转。
+ runDingTalkSyncAsync(c.Request.Context(), func(ctx context.Context) {
+ h.syncDingTalkIdentity(ctx, cfg, client, existing.ID, staff, false)
+ })
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: &existing.ID,
+ ResolvedEmail: existing.Email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{"redirect": redirectTo},
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ signupBlocked := h.isDingTalkSignupBlocked(c.Request.Context(), cfg)
+
+ // ─── 非命中:require_email=false 走 synthetic email 直接登录 ───
+ if !cfg.RequireEmail {
+ if signupBlocked {
+ // 注册被拦 + 无邮箱可输:唯一出路是绑定已有账户
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil,
+ ResolvedEmail: "", RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: dingTalkBindLoginCompletionResponse(redirectTo),
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+ syntheticEmail := buildDingTalkSyntheticEmail(unionID)
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil,
+ ResolvedEmail: syntheticEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: map[string]any{"redirect": redirectTo, "synthetic_email": syntheticEmail},
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ // ─── require_email=true 且 staff.Email 空 → 补邮箱(默认)或直接 bind_login(注册被拦时) ───
+ if staff.Email == "" {
+ completionResponse := map[string]any{
+ "step": "email_completion",
+ "requires_email_completion": true,
+ "redirect": redirectTo,
+ }
+ if signupBlocked {
+ // 注册被全局关闭且未豁免:跳过补邮箱页,直接进 bind_login 让用户输入已有账户
+ completionResponse = dingTalkBindLoginCompletionResponse(redirectTo)
+ }
+ if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{
+ Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil,
+ ResolvedEmail: "", RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey,
+ UpstreamIdentityClaims: upstreamClaims,
+ CompletionResponse: completionResponse,
+ }); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+ return
+ }
+
+ // ─── L3/L4 有邮箱:统一 choice pending session ───
+ var compatEmailUser *dbent.User
+ if dingTalkLevelThreeEnabled && staff.Email != "" {
+ compatEmailUser, _ = h.findDingTalkCompatEmailUser(c.Request.Context(), staff.Email)
+ }
+ if err := h.createDingTalkOAuthChoicePendingSession(
+ c, identityKey, staff.Email, staff.Email,
+ redirectTo, browserSessionKey, upstreamClaims,
+ staff.Email, compatEmailUser, forceEmailOnSignup,
+ signupBlocked,
+ ); err != nil {
+ redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err))
+ return
+ }
+ redirectToFrontendCallback(c, frontendCallback)
+}
+
+func buildDingTalkSyntheticEmail(userID string) string {
+ return "dingtalk-" + strings.ToLower(strings.TrimSpace(userID)) + service.DingTalkConnectSyntheticEmailDomain
+}
+
+// isDingTalkSignupBlocked 当注册总开关关闭且未开启钉钉企业模式豁免
+// (policy=internal_only + dingtalk_connect_bypass_registration=true)时返回 true。
+// 镜像 service.AuthService.canBypassRegistrationDisabledForOAuth 用于 OAuth callback
+// 早期路由决策:注册被拦 → 跳过补邮箱页直接进 bind_login,避免用户填完表单才报错。
+func (h *AuthHandler) isDingTalkSignupBlocked(ctx context.Context, cfg config.DingTalkConnectConfig) bool {
+ if h.settingSvc == nil {
+ return false
+ }
+ if h.settingSvc.IsRegistrationEnabled(ctx) {
+ return false
+ }
+ if cfg.BypassRegistration && cfg.CorpRestrictionPolicy == "internal_only" {
+ return false
+ }
+ return true
+}
+
+func dingTalkBindLoginCompletionResponse(redirectTo string) map[string]any {
+ return map[string]any{
+ "step": "bind_login_required",
+ "existing_account_bindable": true,
+ "create_account_allowed": false,
+ "redirect": redirectTo,
+ }
+}
+
+func buildDingTalkUpstreamClaims(staff *DingTalkStaffInfo, unionID, corpID string) map[string]any {
+ primaryDeptID := int64(0)
+ if len(staff.DeptIDs) > 0 {
+ primaryDeptID = staff.DeptIDs[0]
+ }
+ return map[string]any{
+ "email": staff.Email,
+ "username": staff.Name,
+ "nickname": staff.Nickname,
+ "subject": unionID, // 与 identityKey.ProviderSubject 保持一致(全局唯一 unionID)
+ "corp_user_id": staff.UserID, // 企业 userid(跨组织时为空),保留作独立字段用于 audit
+ "union_id": unionID,
+ "corp_id": corpID,
+ "primary_dept_id": primaryDeptID, // 首个部门 ID,用于 internal_only 同步路径
+ }
+}
+
+func checkDingTalkCorpAllowed(cfg config.DingTalkConnectConfig, corpID string) bool {
+ switch cfg.CorpRestrictionPolicy {
+ case "internal_only":
+ // 方案 A:完全跳过 corpID 字段校验,由 step 3 `GetUserIdByUnionId` 做真实判定。
+ // 原因:钉钉 /v1.0/oauth2/userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)
+ // 不会返回 corpId 字段。而 step 3 用本企业 appToken 查 unionId→userId 映射,
+ // 跨企业用户会被钉钉拒绝(错误码 60011/60121),mapDingTalkErrorCode 已将其映射回 "corp_rejected"。
+ // AppType=internal 已由 ValidateDingTalkConfig 强制保证应用属性。
+ return true
+ case "none", "":
+ return true
+ default:
+ return false
+ }
+}
+
+// decideDingTalkStep34Strategy 根据 policy 和 Step 3/4 运行时错误决定处理方式。
+// 返回 (proceed bool, fatal bool):
+// - proceed=true:继续处理(step 成功或降级)
+// - fatal=true:应 hard fail(upstream_error)
+//
+// 此 helper 从主链中提取,便于 unit test 独立验证策略决策逻辑。
+func decideDingTalkStep34Strategy(policy string, stepErr error) (shouldFallback bool, isFatal bool) {
+ if stepErr == nil {
+ return false, false // 成功,不需要降级
+ }
+ switch policy {
+ case "internal_only":
+ return false, true // hard fail:同企业 Step 3/4 必须成功
+ case "none", "":
+ return true, false // 降级:公网场景跨组织用户失败属正常预期
+ default:
+ return false, true // 未知 policy,视为 hard fail
+ }
+}
+
+// mapDingTalkErrorCode 把 DingTalkAPIError 映射到 redirectOAuthError 用的字符串 code
+func mapDingTalkErrorCode(err error) string {
+ var apiErr *DingTalkAPIError
+ if !errors.As(err, &apiErr) {
+ return "upstream_error"
+ }
+ switch apiErr.Code {
+ case "60011", "60121":
+ return "corp_rejected"
+ case "40014", "50015", "88":
+ return "upstream_error"
+ default:
+ return "upstream_error"
+ }
+}
+
+// dingTalkClient 构造或返回缓存的 client 实例(h-level 单例)。
+// 若 cfg 关键字段(ClientID/ClientSecret/TokenURL/UserInfoURL)与已缓存实例不一致,
+// 则丢弃旧实例(含 appToken 缓存)并重建,避免管理员改配置后旧凭据持续生效。
+func (h *AuthHandler) dingTalkClient(cfg config.DingTalkConnectConfig) *DingTalkClient {
+ h.dingTalkClientMu.Lock()
+ defer h.dingTalkClientMu.Unlock()
+ newCfg := dingTalkClientConfig{
+ ClientID: cfg.ClientID,
+ ClientSecret: cfg.ClientSecret,
+ TokenURL: cfg.TokenURL,
+ UserInfoURL: cfg.UserInfoURL,
+ }
+ if h.dingTalkClientInstance == nil || h.dingTalkClientInstance.cfg != newCfg {
+ h.dingTalkClientInstance = &DingTalkClient{
+ cfg: newCfg,
+ // 与 wechat OAuth client 对齐,避免上游网络抖动时请求悬挂。
+ httpClient: &http.Client{Timeout: 10 * time.Second},
+ }
+ }
+ return h.dingTalkClientInstance
+}
+
+// ─── buildDingTalkAuthorizeURL ─────────────────────────────────────────────
+
+// buildDingTalkAuthorizeURL 根据配置和 state 构建钉钉 OAuth 授权 URL。
+func buildDingTalkAuthorizeURL(cfg config.DingTalkConnectConfig, state string) (string, error) {
+ base := strings.TrimSpace(cfg.AuthorizeURL)
+ if base == "" {
+ return "", infraerrors.InternalServer("DINGTALK_AUTHORIZE_URL_EMPTY", "dingtalk authorize_url not configured")
+ }
+ redirectURI := strings.TrimSpace(cfg.RedirectURL)
+ if redirectURI == "" {
+ return "", infraerrors.InternalServer("DINGTALK_REDIRECT_URL_EMPTY", "dingtalk redirect_url not configured")
+ }
+
+ u, err := url.Parse(base)
+ if err != nil {
+ return "", infraerrors.InternalServer("DINGTALK_AUTHORIZE_URL_PARSE_FAILED", "failed to parse dingtalk authorize_url").WithCause(err)
+ }
+
+ scopes := strings.TrimSpace(cfg.Scopes)
+ if scopes == "" {
+ scopes = "openid"
+ }
+
+ q := u.Query()
+ q.Set("client_id", cfg.ClientID)
+ q.Set("redirect_uri", redirectURI)
+ q.Set("response_type", "code")
+ q.Set("scope", scopes)
+ q.Set("state", state)
+ q.Set("prompt", "consent")
+ u.RawQuery = q.Encode()
+
+ return u.String(), nil
+}
+
+// ─── Complete Registration ─────────────────────────────────────────────────
+
+type completeDingTalkOAuthRequest struct {
+ InvitationCode string `json:"invitation_code" binding:"required"`
+ AffCode string `json:"aff_code,omitempty"`
+ AdoptDisplayName *bool `json:"adopt_display_name,omitempty"`
+ AdoptAvatar *bool `json:"adopt_avatar,omitempty"`
+}
+
+// CompleteDingTalkOAuthRegistration completes a pending OAuth registration by validating
+// the invitation code and creating the user account.
+// POST /api/v1/auth/oauth/dingtalk/complete-registration
+func (h *AuthHandler) CompleteDingTalkOAuthRegistration(c *gin.Context) {
+ var req completeDingTalkOAuthRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()})
+ return
+ }
+
+ secureCookie := isRequestHTTPS(c)
+ sessionToken, err := readOAuthPendingSessionCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound)
+ return
+ }
+ browserSessionKey, err := readOAuthPendingBrowserCookie(c)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch)
+ return
+ }
+ pendingSvc, err := h.pendingIdentityService()
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey)
+ if err != nil {
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ } else if handled {
+ c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession))
+ return
+ } else {
+ session = updatedSession
+ }
+ if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ email := strings.TrimSpace(session.ResolvedEmail)
+ username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username")
+ // E: username 空时退到 email local part(跨组织用户没拿到 staff.Name 也能注册)
+ if username == "" {
+ if at := strings.Index(email, "@"); at > 0 {
+ username = email[:at]
+ } else {
+ username = email
+ }
+ }
+ if email == "" || username == "" {
+ response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid"))
+ return
+ }
+
+ client := h.entClient()
+ if client == nil {
+ response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready"))
+ return
+ }
+ if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{
+ AdoptDisplayName: req.AdoptDisplayName,
+ AdoptAvatar: req.AdoptAvatar,
+ })
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "dingtalk")
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil {
+ respondPendingOAuthBindingApplyError(c, err)
+ return
+ }
+ // 新用户注册完成后执行身份同步(user_id 现在已知)。
+ // 异步执行避免阻塞 token 响应。
+ if completionCfg, cfgErr := h.getDingTalkOAuthConfig(c.Request.Context()); cfgErr == nil {
+ dtClient := h.dingTalkClient(completionCfg)
+ claims := session.UpstreamIdentityClaims
+ runDingTalkSyncAsync(c.Request.Context(), func(ctx context.Context) {
+ h.syncDingTalkIdentityFromClaims(ctx, completionCfg, dtClient, user.ID, claims, true)
+ })
+ }
+ h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ clearOAuthPendingSessionCookie(c, secureCookie)
+ clearOAuthPendingBrowserCookie(c, secureCookie)
+
+ c.JSON(http.StatusOK, gin.H{
+ "access_token": tokenPair.AccessToken,
+ "refresh_token": tokenPair.RefreshToken,
+ "expires_in": tokenPair.ExpiresIn,
+ "token_type": "Bearer",
+ })
+}
+
+// CreateDingTalkOAuthAccount creates a new user account from a pending DingTalk OAuth session.
+// POST /api/v1/auth/oauth/dingtalk/create-account
+func (h *AuthHandler) CreateDingTalkOAuthAccount(c *gin.Context) {
+ h.createPendingOAuthAccount(c, "dingtalk")
+}
+
+// BindDingTalkOAuthLogin 处理已有账户绑定钉钉 OAuth 登录。
+// POST /api/v1/auth/oauth/dingtalk/bind-login
+func (h *AuthHandler) BindDingTalkOAuthLogin(c *gin.Context) {
+ h.bindPendingOAuthLogin(c, "dingtalk")
+}
+
+// ─── DingTalk 身份同步 ─────────────────────────────────────────────────────
+
+// runDingTalkSyncAsync 在后台 goroutine 执行钉钉身份同步,避免阻塞登录响应。
+// 与请求 ctx 解耦(handler 返回后会被取消),但保留其 values(trace/request id)。
+// 固定 30s 超时上限,防止 goroutine 因上游卡顿无限挂起。
+func runDingTalkSyncAsync(parent context.Context, fn func(ctx context.Context)) {
+ base := context.WithoutCancel(parent)
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ slog.Error("dingtalk sync: panic recovered", "panic", r)
+ }
+ }()
+ ctx, cancel := context.WithTimeout(base, 30*time.Second)
+ defer cancel()
+ fn(ctx)
+ }()
+}
+
+// syncDingTalkIdentity 在 internal_only 模式下,按三个 sync 开关把钉钉身份信息
+// 同步到用户属性表(以及 users.username)。
+// 任何错误仅记日志,不中断登录流程(最终一致性)。
+func (h *AuthHandler) syncDingTalkIdentity(ctx context.Context, cfg config.DingTalkConnectConfig, client *DingTalkClient, userID int64, staff *DingTalkStaffInfo, syncUsername bool) {
+ slog.Info("dingtalk sync: entry",
+ "user_id", userID,
+ "policy", cfg.CorpRestrictionPolicy,
+ "sync_corp_email", cfg.SyncCorpEmail,
+ "sync_display_name", cfg.SyncDisplayName,
+ "sync_dept", cfg.SyncDept,
+ "sync_username", syncUsername,
+ "attr_key_email", cfg.SyncCorpEmailAttrKey,
+ "attr_key_name", cfg.SyncDisplayNameAttrKey,
+ "attr_key_dept", cfg.SyncDeptAttrKey,
+ "staff_nil", staff == nil,
+ )
+ if cfg.CorpRestrictionPolicy != "internal_only" || staff == nil {
+ slog.Info("dingtalk sync: skip, not internal_only or staff nil")
+ return
+ }
+ slog.Info("dingtalk sync: staff snapshot",
+ "name", staff.Name, "email", staff.Email, "dept_ids", staff.DeptIDs,
+ )
+ if !cfg.SyncCorpEmail && !cfg.SyncDisplayName && !cfg.SyncDept {
+ slog.Info("dingtalk sync: skip, all flags disabled")
+ return
+ }
+ if h.userAttributeService == nil {
+ slog.Warn("dingtalk sync: userAttributeService not available, skipping")
+ return
+ }
+
+ // 仅首次注册时覆盖 users.username(避免每次登录覆盖用户后续手动改过的名字)。
+ // dingtalk_name 属性下面单独每次写入企业 name,不受此条件影响。
+ if syncUsername && cfg.SyncDisplayName {
+ username := strings.TrimSpace(staff.Nickname)
+ source := "nickname"
+ if username == "" {
+ username = strings.TrimSpace(staff.Name)
+ source = "name(fallback)"
+ }
+ if username != "" && h.userService != nil {
+ if _, err := h.userService.UpdateProfile(ctx, userID, service.UpdateProfileRequest{Username: &username}); err != nil {
+ slog.Warn("dingtalk sync: failed to update username", "user_id", userID, "err", err)
+ } else {
+ slog.Info("dingtalk sync: username updated (register)", "user_id", userID, "username", username, "source", source)
+ }
+ }
+ }
+
+ // 属性同步(目标 attr key 从 cfg 读取,默认值由 GetDingTalkConnectOAuthConfig 保证非空)
+ type syncField struct {
+ key string
+ value string
+ }
+ var fields []syncField
+
+ if cfg.SyncDisplayName && strings.TrimSpace(staff.Name) != "" {
+ fields = append(fields, syncField{cfg.SyncDisplayNameAttrKey, strings.TrimSpace(staff.Name)})
+ }
+ if cfg.SyncCorpEmail && strings.TrimSpace(staff.Email) != "" {
+ fields = append(fields, syncField{cfg.SyncCorpEmailAttrKey, strings.TrimSpace(staff.Email)})
+ }
+ if cfg.SyncDept && len(staff.DeptIDs) > 0 {
+ // 跳过根部门 ID=1,找第一个真实子部门;都是根则保留 1(最终写入空字符串覆盖旧值)。
+ primaryDeptID := int64(0)
+ for _, id := range staff.DeptIDs {
+ if id > 1 {
+ primaryDeptID = id
+ break
+ }
+ }
+ if primaryDeptID == 0 {
+ primaryDeptID = staff.DeptIDs[0]
+ }
+ slog.Info("dingtalk sync: pick primary dept", "user_id", userID, "all_dept_ids", staff.DeptIDs, "primary", primaryDeptID)
+ path, err := h.resolveDingTalkDeptPath(ctx, client, primaryDeptID)
+ if err != nil {
+ slog.Warn("dingtalk sync: failed to resolve dept path", "user_id", userID, "dept_id", primaryDeptID, "err", err)
+ } else {
+ // path="" 表示公司直属(仅在根部门下),仍写入空串覆盖旧值。
+ fields = append(fields, syncField{cfg.SyncDeptAttrKey, path})
+ }
+ }
+
+ if len(fields) == 0 {
+ return
+ }
+
+ // 逐 key 查 definition 并 upsert
+ for _, f := range fields {
+ if err := h.setUserAttributeByKey(ctx, userID, f.key, f.value); err != nil {
+ slog.Warn("dingtalk sync: failed to set attribute", "user_id", userID, "key", f.key, "err", err)
+ }
+ }
+}
+
+// syncDingTalkIdentityFromClaims 从 upstreamClaims 恢复 DingTalkStaffInfo 并调用 syncDingTalkIdentity。
+// 用于 pending session 完成阶段(complete-registration / create-account / bind-login)。
+// syncUsername=true 表示首次注册场景,需要把 nickname 写入 users.username。
+func (h *AuthHandler) syncDingTalkIdentityFromClaims(ctx context.Context, cfg config.DingTalkConnectConfig, client *DingTalkClient, userID int64, claims map[string]any, syncUsername bool) {
+ staff := dingTalkStaffFromClaims(claims)
+ h.syncDingTalkIdentity(ctx, cfg, client, userID, staff, syncUsername)
+}
+
+// maybeSyncDingTalkAfterRegistration 在通用 OAuth 注册路径完成后调用。
+// 同步 4 个字段:users.username(首次) + dingtalk_name/email/department(每次)。
+func (h *AuthHandler) maybeSyncDingTalkAfterRegistration(ctx context.Context, session *dbent.PendingAuthSession, userID int64) {
+ h.dispatchDingTalkPendingSync(ctx, session, userID, true)
+}
+
+// maybeSyncDingTalkAfterLogin 在通用 OAuth 登录/绑定路径完成后调用。
+// 仅刷新 3 个属性(dingtalk_name/email/department),不动 users.username。
+func (h *AuthHandler) maybeSyncDingTalkAfterLogin(ctx context.Context, session *dbent.PendingAuthSession, userID int64) {
+ h.dispatchDingTalkPendingSync(ctx, session, userID, false)
+}
+
+func (h *AuthHandler) dispatchDingTalkPendingSync(ctx context.Context, session *dbent.PendingAuthSession, userID int64, syncUsername bool) {
+ if session == nil || userID <= 0 {
+ return
+ }
+ if !strings.EqualFold(strings.TrimSpace(session.ProviderType), "dingtalk") {
+ return
+ }
+ cfg, err := h.getDingTalkOAuthConfig(ctx)
+ if err != nil {
+ slog.Debug("dingtalk sync: skip post-login sync, config unavailable", "user_id", userID, "err", err.Error())
+ return
+ }
+ client := h.dingTalkClient(cfg)
+ claims := session.UpstreamIdentityClaims
+ // 异步执行避免阻塞 token 响应。
+ runDingTalkSyncAsync(ctx, func(asyncCtx context.Context) {
+ h.syncDingTalkIdentityFromClaims(asyncCtx, cfg, client, userID, claims, syncUsername)
+ })
+}
+
+// dingTalkStaffFromClaims 从 upstreamClaims 重建最小 DingTalkStaffInfo。
+func dingTalkStaffFromClaims(claims map[string]any) *DingTalkStaffInfo {
+ if claims == nil {
+ return &DingTalkStaffInfo{}
+ }
+ staff := &DingTalkStaffInfo{}
+ if v, ok := claims["username"].(string); ok {
+ staff.Name = v
+ }
+ if v, ok := claims["nickname"].(string); ok {
+ staff.Nickname = v
+ }
+ if v, ok := claims["email"].(string); ok {
+ staff.Email = v
+ }
+ if v, ok := claims["corp_user_id"].(string); ok {
+ staff.UserID = v
+ }
+ // primary_dept_id 存为 int64 或 float64(JSON round-trip)
+ switch v := claims["primary_dept_id"].(type) {
+ case int64:
+ if v > 0 {
+ staff.DeptIDs = []int64{v}
+ }
+ case float64:
+ if id := int64(v); id > 0 {
+ staff.DeptIDs = []int64{id}
+ }
+ }
+ return staff
+}
+
+// setUserAttributeByKey 按 attribute key 查找 definition,再 upsert 用户属性值。
+// definition 不存在时记 warn 日志跳过(admin 在 settings 保存时已按需 upsert
+// 对应 def;缺失意味着 admin 改了 attr key 但未保存 settings,或 def 被手工删除)。
+func (h *AuthHandler) setUserAttributeByKey(ctx context.Context, userID int64, key, value string) error {
+ def, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
+ if err != nil {
+ slog.Warn("dingtalk sync: attribute definition not found, skipping", "key", key, "err", err.Error())
+ return nil
+ }
+ if err := h.userAttributeService.UpdateUserAttributes(ctx, userID, []service.UpdateUserAttributeInput{
+ {AttributeID: def.ID, Value: value},
+ }); err != nil {
+ return err
+ }
+ slog.Info("dingtalk sync: attribute upserted", "user_id", userID, "key", key, "attr_id", def.ID)
+ return nil
+}
+
+// resolveDingTalkDeptPath 从叶部门递归向上拼 "公司/部门/子部门" 路径字符串。
+// 遇 dept_id=1(根)或 parent_id=0 停止。加 visited set 防循环,最多 50 层。
+func (h *AuthHandler) resolveDingTalkDeptPath(ctx context.Context, client *DingTalkClient, deptID int64) (string, error) {
+ slog.Info("dingtalk sync: resolve dept path start", "dept_id", deptID)
+ const maxDepth = 50
+ visited := make(map[int64]bool, maxDepth)
+ var parts []string
+
+ current := deptID
+ for i := 0; i < maxDepth; i++ {
+ if current < 1 || visited[current] {
+ break
+ }
+ visited[current] = true
+
+ info, err := client.GetDeptInfo(ctx, current)
+ if err != nil {
+ return "", fmt.Errorf("get dept info %d: %w", current, err)
+ }
+ if strings.TrimSpace(info.Name) != "" {
+ parts = append([]string{strings.TrimSpace(info.Name)}, parts...)
+ }
+ // 钉钉根部门 dept_id=1,ParentID 通常为 0;遇到 0 / self 终止避免循环。
+ if info.ParentID < 1 || info.ParentID == current {
+ break
+ }
+ current = info.ParentID
+ }
+
+ // 去除根组织名(parts[0] 始终是企业全称),仅保留部门层级。
+ // 例:["公司","A","B"] → "A/B";["公司"] → ""(公司直属)。
+ if len(parts) > 0 {
+ parts = parts[1:]
+ }
+
+ return strings.Join(parts, "/"), nil
+}
diff --git a/backend/internal/handler/auth_dingtalk_oauth_test.go b/backend/internal/handler/auth_dingtalk_oauth_test.go
new file mode 100644
index 00000000..1f60e6b6
--- /dev/null
+++ b/backend/internal/handler/auth_dingtalk_oauth_test.go
@@ -0,0 +1,391 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// TestDingTalkOAuthStart_Disabled は sentinel テスト。
+// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。
+func TestDingTalkOAuthStart_Disabled(t *testing.T) {
+ t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only")
+}
+
+// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。
+func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) {
+ unionID := "union_AbCdEf123"
+ email := buildDingTalkSyntheticEmail(unionID)
+
+ want := "dingtalk-union_abcdef123@dingtalk-connect.invalid"
+ require.Equal(t, want, email)
+
+ // 确保结果都是小写(邮箱大小写不敏感,统一小写)
+ require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase")
+
+ // 确保前缀正确
+ require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix")
+
+ // 确保后缀是合成邮箱域名
+ require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix")
+}
+
+// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。
+func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) {
+ email := buildDingTalkSyntheticEmail(" UID_XYZ ")
+ require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email)
+}
+
+// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct(跨组织降级路径)时:
+// - subject 等于 unionID(与 identityKey.ProviderSubject 一致)
+// - corp_user_id 为空字符串(跨组织时拿不到企业 userid)
+// - email/username 为空字符串
+// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{},claims 不应有 nil。
+func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) {
+ staff := &DingTalkStaffInfo{}
+ claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X")
+
+ require.Equal(t, "", claims["email"])
+ require.Equal(t, "", claims["username"])
+ // 重构后 subject = unionID(与 identityKey.ProviderSubject 保持一致)
+ require.Equal(t, "UNION_AAA", claims["subject"])
+ require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空
+ require.Equal(t, "UNION_AAA", claims["union_id"])
+ require.Equal(t, "CORP_X", claims["corp_id"])
+}
+
+// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。
+// D: corp 校验提前后逻辑不变。
+func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) {
+ cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"}
+
+ assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp")
+ assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp")
+ assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp")
+}
+
+// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。
+// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段,
+// 因此 checkDingTalkCorpAllowed 完全不校验 corpID,由 step 3 GetUserIdByUnionId 做真实判定
+// (跨企业用户会被钉钉错误码 60011/60121 拒绝,mapDingTalkErrorCode 映射回 corp_rejected)。
+func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) {
+ cfgWithCorpID := config.DingTalkConnectConfig{
+ CorpRestrictionPolicy: "internal_only",
+ InternalCorpID: "dingInternal",
+ }
+ assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed")
+ assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策,step 3 兜底")
+ assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId)")
+
+ cfgNoCorpID := config.DingTalkConnectConfig{
+ CorpRestrictionPolicy: "internal_only",
+ InternalCorpID: "",
+ }
+ assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过")
+ assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过")
+}
+
+// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时
+// Step 3/4 失败应降级(shouldFallback=true, isFatal=false)。
+func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) {
+ step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
+
+ shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err)
+
+ require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback")
+ require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal")
+}
+
+// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。
+func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) {
+ stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
+
+ shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr)
+
+ require.True(t, shouldFallback, "policy='': step failure should trigger fallback")
+ require.False(t, isFatal, "policy='': step failure should NOT be fatal")
+}
+
+// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时
+// Step 3/4 失败应 hard fail(isFatal=true)。
+func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) {
+ step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
+
+ shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err)
+
+ require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error")
+ require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal")
+}
+
+// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。
+func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) {
+ for _, policy := range []string{"none", "internal_only", ""} {
+ shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil)
+ require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy)
+ require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy)
+ }
+}
+
+// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时
+// 退到 email local part(@ 之前的部分)。
+// E: CompleteDingTalkOAuthRegistration username fallback。
+func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) {
+ tests := []struct {
+ name string
+ email string
+ username string
+ wantUser string
+ wantValid bool
+ }{
+ {
+ name: "username empty, normal email → local part",
+ email: "dingtalk-uid123@dingtalk-connect.invalid",
+ username: "",
+ wantUser: "dingtalk-uid123",
+ wantValid: true,
+ },
+ {
+ name: "username already set → keep original",
+ email: "user@example.com",
+ username: "张三",
+ wantUser: "张三",
+ wantValid: true,
+ },
+ {
+ name: "username empty, no @ in email → use whole email",
+ email: "noemail",
+ username: "",
+ wantUser: "noemail",
+ wantValid: true,
+ },
+ {
+ name: "both empty → invalid",
+ email: "",
+ username: "",
+ wantUser: "",
+ wantValid: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ username := tc.username
+ email := tc.email
+
+ // 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑
+ if username == "" {
+ if at := strings.Index(email, "@"); at > 0 {
+ username = email[:at]
+ } else {
+ username = email
+ }
+ }
+
+ isValid := email != "" && username != ""
+ require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email))
+ require.Equal(t, tc.wantValid, isValid, "validity check")
+ })
+ }
+}
+
+// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID
+// 而非 staff.UserID,与 identityKey.ProviderSubject 保持一致。
+// §4.2: buildDingTalkUpstreamClaims subject 字段修正。
+func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) {
+ staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"}
+ claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789")
+
+ // 重构后 subject = unionID(全局唯一,与 identityKey.ProviderSubject 一致)
+ require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor")
+ // 企业 userid 保留为独立字段,供 audit/debug 使用
+ require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID")
+ // union_id 字段与 subject 相同(冗余保留,便于读取)
+ require.Equal(t, "union456", claims["union_id"])
+ require.Equal(t, "dingcorp789", claims["corp_id"])
+ require.Equal(t, "张三", claims["username"])
+ require.Equal(t, "zhangsan@corp.com", claims["email"])
+}
+
+// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时
+// corp_user_id 为空字符串(跨组织拿不到企业 userid),subject 仍为 unionID。
+func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) {
+ // 跨组织降级路径:staff = &DingTalkStaffInfo{}(所有字段为零值)
+ staff := &DingTalkStaffInfo{}
+ claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp")
+
+ require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users")
+ require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback")
+ require.Equal(t, "", claims["email"])
+ require.Equal(t, "", claims["username"])
+}
+
+// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。
+func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) {
+ staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}}
+ claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX")
+
+ // 只取首个 dept_id
+ require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id")
+}
+
+// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。
+func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) {
+ staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"}
+ claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY")
+
+ require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts")
+}
+
+// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。
+func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) {
+ staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}}
+ claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ")
+
+ recovered := dingTalkStaffFromClaims(claims)
+ require.Equal(t, "王五", recovered.Name)
+ require.Equal(t, "ww@corp.com", recovered.Email)
+ require.Equal(t, "u3", recovered.UserID)
+ require.Equal(t, []int64{55}, recovered.DeptIDs)
+}
+
+// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门(parent_id=1)返回部门名。
+func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) {
+ handler := &AuthHandler{}
+ callCount := 0
+ responses := map[string]string{
+ "42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`,
+ "1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
+ }
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ callCount++
+ var req struct {
+ DeptID int64 `json:"dept_id"`
+ }
+ _ = json.NewDecoder(r.Body).Decode(&req)
+ w.Header().Set("Content-Type", "application/json")
+ if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok {
+ _, _ = w.Write([]byte(resp))
+ } else {
+ _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
+ }
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
+ httpClient: server.Client(),
+ }
+ cli.appToken = "tok"
+ cli.appTokenExp = time.Now().Add(time.Hour)
+
+ path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
+ require.NoError(t, err)
+ require.Equal(t, "研发部", path)
+ require.Equal(t, 2, callCount)
+}
+
+// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key
+// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证
+// syncField 构建逻辑(即 attr key 从 cfg 读取)。
+// 间接验证:通过构造定制 cfg,确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic)。
+func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) {
+ handler := &AuthHandler{
+ userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic
+ }
+
+ cfg := config.DingTalkConnectConfig{
+ CorpRestrictionPolicy: "internal_only",
+ SyncCorpEmail: true,
+ SyncDisplayName: true,
+ SyncDept: true,
+ // 自定义 attr key(非默认值)
+ SyncCorpEmailAttrKey: "custom_email_key",
+ SyncDisplayNameAttrKey: "custom_name_key",
+ SyncDeptAttrKey: "custom_dept_key",
+ }
+
+ staff := &DingTalkStaffInfo{
+ Name: "张三",
+ Email: "zhangsan@example.com",
+ }
+
+ // 调用不应 panic(userAttributeService 为 nil 时走 warn 跳过路径)
+ require.NotPanics(t, func() {
+ handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false)
+ })
+}
+
+// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时
+// 使用 fallback 默认值(dingtalk_email / dingtalk_name / dingtalk_department)。
+// 此测试主要验证调用路径不 panic;实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。
+func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) {
+ handler := &AuthHandler{
+ userAttributeService: nil,
+ }
+
+ cfg := config.DingTalkConnectConfig{
+ CorpRestrictionPolicy: "internal_only",
+ SyncCorpEmail: true,
+ SyncDisplayName: true,
+ SyncDept: false,
+ // 不设置 attr key(等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充)
+ SyncCorpEmailAttrKey: "dingtalk_email",
+ SyncDisplayNameAttrKey: "dingtalk_name",
+ SyncDeptAttrKey: "dingtalk_department",
+ }
+
+ staff := &DingTalkStaffInfo{
+ Name: "李四",
+ Email: "lisi@corp.com",
+ }
+
+ require.NotPanics(t, func() {
+ handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false)
+ })
+}
+
+// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。
+func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) {
+ handler := &AuthHandler{}
+ // 模拟:42(AI研发) → parent=10(研发部) → parent=1(根)
+ responses := map[string]string{
+ "42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`,
+ "10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`,
+ "1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
+ }
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // 解析请求 body 拿到 dept_id
+ var req struct {
+ DeptID int64 `json:"dept_id"`
+ }
+ _ = json.NewDecoder(r.Body).Decode(&req)
+ key := fmt.Sprintf("%d", req.DeptID)
+ w.Header().Set("Content-Type", "application/json")
+ if resp, ok := responses[key]; ok {
+ _, _ = w.Write([]byte(resp))
+ } else {
+ _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
+ }
+ }))
+ defer server.Close()
+
+ cli := &DingTalkClient{
+ cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
+ httpClient: server.Client(),
+ }
+ cli.appToken = "tok"
+ cli.appTokenExp = time.Now().Add(time.Hour)
+
+ path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
+ require.NoError(t, err)
+ require.Equal(t, "研发部/AI研发", path)
+}
diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go
index 1f9a66ff..a9af910d 100644
--- a/backend/internal/handler/auth_handler.go
+++ b/backend/internal/handler/auth_handler.go
@@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"strings"
+ "sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -18,25 +19,30 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
- cfg *config.Config
- authService *service.AuthService
- userService *service.UserService
- settingSvc *service.SettingService
- promoService *service.PromoService
- redeemService *service.RedeemService
- totpService *service.TotpService
+ cfg *config.Config
+ authService *service.AuthService
+ userService *service.UserService
+ settingSvc *service.SettingService
+ promoService *service.PromoService
+ redeemService *service.RedeemService
+ totpService *service.TotpService
+ userAttributeService *service.UserAttributeService
+
+ dingTalkClientInstance *DingTalkClient
+ dingTalkClientMu sync.Mutex
}
// NewAuthHandler creates a new AuthHandler
-func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
+func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler {
return &AuthHandler{
- cfg: cfg,
- authService: authService,
- userService: userService,
- settingSvc: settingService,
- promoService: promoService,
- redeemService: redeemService,
- totpService: totpService,
+ cfg: cfg,
+ authService: authService,
+ userService: userService,
+ settingSvc: settingService,
+ promoService: promoService,
+ redeemService: redeemService,
+ totpService: totpService,
+ userAttributeService: userAttributeService,
}
}
diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go
index 7df4abfd..f0ea5fde 100644
--- a/backend/internal/handler/auth_linuxdo_oauth.go
+++ b/backend/internal/handler/auth_linuxdo_oauth.go
@@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
- strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo")
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go
index 490afd0f..1014a3e8 100644
--- a/backend/internal/handler/auth_oauth_pending_flow.go
+++ b/backend/internal/handler/auth_oauth_pending_flow.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"net/http"
"net/url"
"strings"
@@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
},
})
if err != nil {
+ slog.Error("pending auth session create failed",
+ "intent", strings.TrimSpace(payload.Intent),
+ "provider_type", strings.TrimSpace(payload.Identity.ProviderType),
+ "provider_key", strings.TrimSpace(payload.Identity.ProviderKey),
+ "provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)),
+ "resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)),
+ "has_target_user", payload.TargetUserID != nil,
+ "error", err.Error())
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
@@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
+// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。
+// 钉钉跨组织/staff 邮箱缺失时进入此状态:前端跳到补邮箱页,exchange 不应走 adoption apply。
+func pendingSessionRequiresEmailCompletion(payload map[string]any) bool {
+ if v, ok := payload["requires_email_completion"].(bool); ok && v {
+ return true
+ }
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion")
+}
+
+// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。
+// 钉钉 signupBlocked=true(注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单,
+// exchange 不应消费 session,否则后续 /pending/bind-login 找不到 session。
+func pendingSessionRequiresBindLogin(payload map[string]any) bool {
+ return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required")
+}
+
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
@@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
+ // 把多种 choice 别名归一为 oauthPendingChoiceStep;bind_login_required 是独立终态
+ // (前端渲染 needsBindLogin 而非 needsChooser),故不能并入归一化列表。
switch step {
- case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
+ case "choice", "choose_account_action", "choose_account", "choose", "email_required":
normalized["step"] = oauthPendingChoiceStep
}
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
@@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ // bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username(用户已有自己的名字)
+ h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
@@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
+ // createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值
+ h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
@@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.Success(c, payload)
return
}
+ if pendingSessionRequiresEmailCompletion(payload) {
+ response.Success(c, payload)
+ return
+ }
+ if pendingSessionRequiresBindLogin(payload) {
+ response.Success(c, payload)
+ return
+ }
if !adoptionDecision.hasDecision() {
adoptionRequired, _ := payload["adoption_required"].(bool)
if adoptionRequired {
diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go
index 4264002d..c7c517c8 100644
--- a/backend/internal/handler/auth_oidc_oauth.go
+++ b/backend/internal/handler/auth_oidc_oauth.go
@@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string)
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
- strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
+ strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc")
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go
index 34e70ed0..2199c5bd 100644
--- a/backend/internal/handler/auth_wechat_oauth.go
+++ b/backend/internal/handler/auth_wechat_oauth.go
@@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
- tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
+ tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat")
if err != nil {
response.ErrorFrom(c, err)
return
diff --git a/backend/internal/handler/dto/account_mapper_redact_test.go b/backend/internal/handler/dto/account_mapper_redact_test.go
new file mode 100644
index 00000000..bd584e11
--- /dev/null
+++ b/backend/internal/handler/dto/account_mapper_redact_test.go
@@ -0,0 +1,67 @@
+package dto
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) {
+ src := &service.Account{
+ ID: 42,
+ Name: "demo",
+ Platform: "anthropic",
+ Type: "oauth",
+ Credentials: map[string]any{
+ "access_token": "at-secret",
+ "refresh_token": "rt-secret",
+ "id_token": "id-secret",
+ "api_key": "sk-secret",
+ "base_url": "https://api.example.com",
+ "model_mapping": map[string]any{"foo": "bar"},
+ },
+ }
+
+ got := AccountFromServiceShallow(src)
+ require.NotNil(t, got)
+
+ // 敏感键不在 Credentials 里
+ require.NotContains(t, got.Credentials, "access_token")
+ require.NotContains(t, got.Credentials, "refresh_token")
+ require.NotContains(t, got.Credentials, "id_token")
+ require.NotContains(t, got.Credentials, "api_key")
+ // 非敏感键保留
+ require.Equal(t, "https://api.example.com", got.Credentials["base_url"])
+ require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"])
+
+ // 状态 map 标记敏感键存在
+ require.True(t, got.CredentialsStatus["has_access_token"])
+ require.True(t, got.CredentialsStatus["has_refresh_token"])
+ require.True(t, got.CredentialsStatus["has_id_token"])
+ require.True(t, got.CredentialsStatus["has_api_key"])
+
+ // JSON 序列化校验:响应体里不会出现敏感子串
+ raw, err := json.Marshal(got)
+ require.NoError(t, err)
+ require.NotContains(t, string(raw), "rt-secret")
+ require.NotContains(t, string(raw), "at-secret")
+ require.NotContains(t, string(raw), "sk-secret")
+ require.NotContains(t, string(raw), "id-secret")
+ // 状态标识应序列化进 JSON
+ require.Contains(t, string(raw), "credentials_status")
+ require.Contains(t, string(raw), "has_refresh_token")
+
+ // 原始 service.Account 不应被改动
+ require.Equal(t, "rt-secret", src.Credentials["refresh_token"])
+}
+
+func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) {
+ src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"}
+ got := AccountFromServiceShallow(src)
+ require.NotNil(t, got)
+ require.Nil(t, got.Credentials)
+ require.Nil(t, got.CredentialsStatus)
+}
diff --git a/backend/internal/handler/dto/credentials_redact.go b/backend/internal/handler/dto/credentials_redact.go
new file mode 100644
index 00000000..e65a8007
--- /dev/null
+++ b/backend/internal/handler/dto/credentials_redact.go
@@ -0,0 +1,44 @@
+// Package dto provides data transfer objects for HTTP handlers.
+package dto
+
+import "github.com/Wei-Shaw/sub2api/internal/service"
+
+// RedactCredentials 复制一份 in,剥离 service.SensitiveCredentialKeys 列出的所有敏感子键,
+// 并产出一个 has_ 状态 map 表示哪些敏感键存在且非零值。
+//
+// 输入 nil 时返回 nil, nil(避免响应里出现空对象)。
+// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。
+func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) {
+ if in == nil {
+ return nil, nil
+ }
+ out = make(map[string]any, len(in))
+ for k, v := range in {
+ if service.IsSensitiveCredentialKey(k) {
+ if isCredentialValuePresent(v) {
+ if status == nil {
+ status = make(map[string]bool, 4)
+ }
+ status["has_"+k] = true
+ }
+ continue
+ }
+ out[k] = v
+ }
+ return out, status
+}
+
+// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置;
+// 其余非零类型(数字、对象、字符串等)视为已配置。
+func isCredentialValuePresent(v any) bool {
+ switch x := v.(type) {
+ case nil:
+ return false
+ case string:
+ return x != ""
+ case bool:
+ return x
+ default:
+ return true
+ }
+}
diff --git a/backend/internal/handler/dto/credentials_redact_test.go b/backend/internal/handler/dto/credentials_redact_test.go
new file mode 100644
index 00000000..431078fa
--- /dev/null
+++ b/backend/internal/handler/dto/credentials_redact_test.go
@@ -0,0 +1,97 @@
+package dto
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRedactCredentials_NilInput(t *testing.T) {
+ out, status := RedactCredentials(nil)
+ require.Nil(t, out)
+ require.Nil(t, status)
+}
+
+func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) {
+ in := map[string]any{
+ "refresh_token": "rt-secret",
+ "access_token": "at-secret",
+ "api_key": "sk-secret",
+ "aws_secret_access_key": "aws-secret",
+ "service_account_json": map[string]any{"private_key": "..."},
+ "private_key": "raw-key",
+ // 非敏感
+ "base_url": "https://api.example.com",
+ "model_mapping": map[string]any{"foo": "bar"},
+ "project_id": "proj-1",
+ "expires_at": int64(123456),
+ }
+
+ out, status := RedactCredentials(in)
+
+ require.NotContains(t, out, "refresh_token")
+ require.NotContains(t, out, "access_token")
+ require.NotContains(t, out, "api_key")
+ require.NotContains(t, out, "aws_secret_access_key")
+ require.NotContains(t, out, "service_account_json")
+ require.NotContains(t, out, "private_key")
+
+ require.Equal(t, "https://api.example.com", out["base_url"])
+ require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
+ require.Equal(t, "proj-1", out["project_id"])
+ require.Equal(t, int64(123456), out["expires_at"])
+
+ require.True(t, status["has_refresh_token"])
+ require.True(t, status["has_access_token"])
+ require.True(t, status["has_api_key"])
+ require.True(t, status["has_aws_secret_access_key"])
+ require.True(t, status["has_service_account_json"])
+ require.True(t, status["has_private_key"])
+
+ // 状态 map 不应携带非敏感键的 has_*
+ require.NotContains(t, status, "has_base_url")
+ require.NotContains(t, status, "has_project_id")
+}
+
+func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) {
+ in := map[string]any{
+ "refresh_token": "",
+ "access_token": nil,
+ "api_key": false,
+ "id_token": "actual-id",
+ }
+ out, status := RedactCredentials(in)
+ require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output")
+ require.False(t, status["has_refresh_token"])
+ require.False(t, status["has_access_token"])
+ require.False(t, status["has_api_key"])
+ require.True(t, status["has_id_token"])
+}
+
+func TestRedactCredentials_DoesNotMutateInput(t *testing.T) {
+ in := map[string]any{
+ "refresh_token": "secret",
+ "base_url": "x",
+ }
+ _, _ = RedactCredentials(in)
+ require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改")
+ require.Equal(t, "x", in["base_url"])
+}
+
+func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) {
+ keys := []string{
+ "access_token", "refresh_token", "id_token",
+ "api_key", "session_key", "cookie",
+ "aws_secret_access_key", "aws_session_token",
+ "service_account_json", "service_account", "private_key",
+ }
+ in := make(map[string]any, len(keys))
+ for _, k := range keys {
+ in[k] = "filled"
+ }
+ out, status := RedactCredentials(in)
+ require.Empty(t, out)
+ for _, k := range keys {
+ require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k)
+ }
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 2559b112..2c71be9d 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
+ redactedCreds, credsStatus := RedactCredentials(a.Credentials)
out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
Platform: a.Platform,
Type: a.Type,
- Credentials: a.Credentials,
+ Credentials: redactedCreds,
+ CredentialsStatus: credsStatus,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
@@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
CreatedAt: rc.CreatedAt,
+ ExpiresAt: rc.ExpiresAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
+ if rc.IsExpired() {
+ out.Status = service.StatusExpired
+ }
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
@@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
+ ImageInputSize: l.ImageInputSize,
+ ImageOutputSize: l.ImageOutputSize,
+ ImageSizeSource: l.ImageSizeSource,
+ ImageSizeBreakdown: l.ImageSizeBreakdown,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,
diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go
index c2635e33..eca838b9 100644
--- a/backend/internal/handler/dto/mappers_usage_test.go
+++ b/backend/internal/handler/dto/mappers_usage_test.go
@@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *
require.Equal(t, "claude-3", adminDTO.Model)
}
+func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) {
+ t.Parallel()
+
+ imageSize := "4K"
+ inputSize := "1024x1024"
+ outputSize := "3840x2160"
+ source := "output"
+ log := &service.UsageLog{
+ RequestID: "req_image_metadata",
+ Model: "gpt-image-2",
+ ImageCount: 2,
+ ImageSize: &imageSize,
+ ImageInputSize: &inputSize,
+ ImageOutputSize: &outputSize,
+ ImageSizeSource: &source,
+ ImageSizeBreakdown: map[string]int{"4K": 2},
+ }
+
+ userDTO := UsageLogFromService(log)
+ adminDTO := UsageLogFromServiceAdmin(log)
+
+ for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} {
+ require.Equal(t, 2, got.ImageCount)
+ require.NotNil(t, got.ImageSize)
+ require.Equal(t, imageSize, *got.ImageSize)
+ require.NotNil(t, got.ImageInputSize)
+ require.Equal(t, inputSize, *got.ImageInputSize)
+ require.NotNil(t, got.ImageOutputSize)
+ require.Equal(t, outputSize, *got.ImageOutputSize)
+ require.NotNil(t, got.ImageSizeSource)
+ require.Equal(t, source, *got.ImageSizeSource)
+ require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown)
+ }
+}
+
+func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) {
+ t.Parallel()
+
+ log := &service.UsageLog{
+ RequestID: "req_legacy_image_missing_size",
+ Model: "gpt-image-2",
+ ImageCount: 1,
+ ImageSize: nil,
+ }
+
+ dto := UsageLogFromService(log)
+ require.Equal(t, 1, dto.ImageCount)
+ require.Nil(t, dto.ImageSize)
+ require.Nil(t, dto.ImageInputSize)
+ require.Nil(t, dto.ImageOutputSize)
+ require.Nil(t, dto.ImageSizeSource)
+ require.Nil(t, dto.ImageSizeBreakdown)
+
+ body, err := json.Marshal(dto)
+ require.NoError(t, err)
+ require.Contains(t, string(body), `"image_size":null`)
+ require.NotContains(t, string(body), `"image_size":"2K"`)
+}
+
func f64Ptr(value float64) *float64 {
return &value
}
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index 551cf0dc..45ad7a70 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -56,6 +56,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
+ DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
+ DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
+ DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"`
+ DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
+ DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
+ DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
+ DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
+ DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
+ DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
+ DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
+ DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
+ DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
+ DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
+ DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
+ DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
+ DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
+
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
@@ -201,6 +218,9 @@ type SystemSettings struct {
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
+ // Force Alipay mobile clients to use QR code payment instead of mobile redirect
+ PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"`
+
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
@@ -260,6 +280,7 @@ type PublicSettings struct {
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
+ DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index e15a916e..cc360f78 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -149,25 +149,28 @@ type AdminGroup struct {
}
type Account struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- Notes *string `json:"notes"`
- Platform string `json:"platform"`
- Type string `json:"type"`
- Credentials map[string]any `json:"credentials"`
- Extra map[string]any `json:"extra"`
- ProxyID *int64 `json:"proxy_id"`
- Concurrency int `json:"concurrency"`
- LoadFactor *int `json:"load_factor,omitempty"`
- Priority int `json:"priority"`
- RateMultiplier float64 `json:"rate_multiplier"`
- Status string `json:"status"`
- ErrorMessage string `json:"error_message"`
- LastUsedAt *time.Time `json:"last_used_at"`
- ExpiresAt *int64 `json:"expires_at"`
- AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
- CreatedAt time.Time `json:"created_at"`
- UpdatedAt time.Time `json:"updated_at"`
+ ID int64 `json:"id"`
+ Name string `json:"name"`
+ Notes *string `json:"notes"`
+ Platform string `json:"platform"`
+ Type string `json:"type"`
+ // Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥
+ // 的存在性通过 CredentialsStatus(has_)暴露,原始值不返回前端。
+ Credentials map[string]any `json:"credentials"`
+ CredentialsStatus map[string]bool `json:"credentials_status,omitempty"`
+ Extra map[string]any `json:"extra"`
+ ProxyID *int64 `json:"proxy_id"`
+ Concurrency int `json:"concurrency"`
+ LoadFactor *int `json:"load_factor,omitempty"`
+ Priority int `json:"priority"`
+ RateMultiplier float64 `json:"rate_multiplier"`
+ Status string `json:"status"`
+ ErrorMessage string `json:"error_message"`
+ LastUsedAt *time.Time `json:"last_used_at"`
+ ExpiresAt *int64 `json:"expires_at"`
+ AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
@@ -335,6 +338,7 @@ type RedeemCode struct {
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `json:"created_at"`
+ ExpiresAt *time.Time `json:"expires_at,omitempty"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
@@ -400,9 +404,13 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
- ImageCount int `json:"image_count"`
- ImageSize *string `json:"image_size"`
- MediaType *string `json:"media_type"`
+ ImageCount int `json:"image_count"`
+ ImageSize *string `json:"image_size"`
+ ImageInputSize *string `json:"image_input_size"`
+ ImageOutputSize *string `json:"image_output_size"`
+ ImageSizeSource *string `json:"image_size_source"`
+ ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
+ MediaType *string `json:"media_type"`
// User-Agent
UserAgent *string `json:"user_agent"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 238bc892..800328bb 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -157,7 +158,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
@@ -209,7 +210,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填
@@ -351,6 +352,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", apiKey.GroupID),
@@ -400,6 +402,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@@ -621,6 +624,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", currentAPIKey.GroupID),
@@ -681,6 +685,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@@ -1041,8 +1046,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
- // Get available models from account configurations (without platform filter)
- availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
+ // Get available models from account configurations for the selected group platform.
+ availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
if len(availableModels) > 0 {
// Build model list from whitelist
@@ -1063,7 +1068,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
// Fallback to default models
- if platform == "openai" {
+ if platform == service.PlatformOpenAI {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
@@ -1071,6 +1076,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
+ if platform == service.PlatformGemini {
+ c.JSON(http.StatusOK, gin.H{
+ "object": "list",
+ "data": geminicli.DefaultModels,
+ })
+ return
+ }
+
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,
@@ -1407,6 +1420,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
+ if service.IsOpenAISilentRefusalErrorBody(responseBody) {
+ service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
+ h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
+ return
+ }
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
@@ -1611,7 +1629,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
@@ -1630,7 +1648,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
- setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
+ setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息(可能为nil)
@@ -1659,6 +1677,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go
index c6b73190..9a091fcd 100644
--- a/backend/internal/handler/gateway_handler_chat_completions.go
+++ b/backend/internal/handler/gateway_handler_chat_completions.go
@@ -60,7 +60,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
// Validate JSON
if !gjson.ValidBytes(body) {
@@ -78,7 +78,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
@@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
+ groupPlatform := ""
+ if apiKey.Group != nil {
+ groupPlatform = apiKey.Group.Platform
+ }
+ selectionSessionHash := sessionHash
+ if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
+ selectionSessionHash = "gemini:" + selectionSessionHash
+ }
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
+ if groupPlatform == service.PlatformGemini {
+ fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
+ }
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
+ if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ fs.FailedAccountIDs[account.ID] = struct{}{}
+ continue
+ }
+
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
- result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
+ var result *service.ForwardResult
+ if account.Platform == service.PlatformGemini {
+ if h.geminiCompatService == nil {
+ h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
+ if accountReleaseFunc != nil {
+ accountReleaseFunc()
+ }
+ return
+ }
+ result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
+ } else {
+ result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
+ }
if accountReleaseFunc != nil {
accountReleaseFunc()
@@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
+ if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
+ service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
+ h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
+ return
+ }
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go
index a97f572d..e1a5b723 100644
--- a/backend/internal/handler/gateway_handler_responses.go
+++ b/backend/internal/handler/gateway_handler_responses.go
@@ -60,7 +60,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
// Validate JSON
if !gjson.ValidBytes(body) {
@@ -78,7 +78,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
@@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
+ if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
+ service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
+ h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
+ return
+ }
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}
diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go
new file mode 100644
index 00000000..78b07a1a
--- /dev/null
+++ b/backend/internal/handler/gateway_models_test.go
@@ -0,0 +1,136 @@
+package handler
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type gatewayModelsAccountRepoStub struct {
+ service.AccountRepository
+
+ byGroup map[int64][]service.Account
+}
+
+type gatewayModelsResponseForTest struct {
+ Object string `json:"object"`
+ Data []gatewayModelItemForTest `json:"data"`
+}
+
+type gatewayModelItemForTest struct {
+ ID string `json:"id"`
+}
+
+func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
+ accounts, ok := s.byGroup[groupID]
+ if !ok {
+ return nil, nil
+ }
+ out := make([]service.Account, len(accounts))
+ copy(out, accounts)
+ return out, nil
+}
+
+func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
+ return &GatewayHandler{
+ gatewayService: service.NewGatewayService(
+ repo,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ ),
+ }
+}
+
+func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ groupID := int64(20)
+ h := newGatewayModelsHandlerForTest(
+ &gatewayModelsAccountRepoStub{
+ byGroup: map[int64][]service.Account{
+ groupID: {
+ {ID: 1, Platform: service.PlatformGemini},
+ },
+ },
+ },
+ )
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
+ c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
+ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
+ })
+
+ h.Models(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var got gatewayModelsResponseForTest
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
+ require.Equal(t, "list", got.Object)
+ require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash")
+ require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6")
+}
+
+func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ groupID := int64(21)
+ h := newGatewayModelsHandlerForTest(
+ &gatewayModelsAccountRepoStub{
+ byGroup: map[int64][]service.Account{
+ groupID: {
+ {
+ ID: 1,
+ Platform: service.PlatformAnthropic,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "claude-sonnet-4-6": "claude-sonnet-4-6",
+ },
+ },
+ },
+ {
+ ID: 2,
+ Platform: service.PlatformGemini,
+ Credentials: map[string]any{
+ "model_mapping": map[string]any{
+ "gemini-2.5-flash": "gemini-2.5-flash",
+ },
+ },
+ },
+ },
+ },
+ },
+ )
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
+ c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
+ Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
+ })
+
+ h.Models(c)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var got gatewayModelsResponseForTest
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
+ require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
+}
+
+func modelIDsForTest(models []gatewayModelItemForTest) []string {
+ ids := make([]string, 0, len(models))
+ for _, model := range models {
+ ids = append(ids, model.ID)
+ }
+ return ids
+}
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 90ebe9ec..665c0677 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@@ -182,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
- setOpsRequestContext(c, modelName, stream, body)
+ setOpsRequestContext(c, modelName, stream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
@@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}
diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go
index de384710..f7269214 100644
--- a/backend/internal/handler/openai_chat_completions.go
+++ b/backend/internal/handler/openai_chat_completions.go
@@ -78,7 +78,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
@@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
} else {
@@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
+ markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
+ writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
+ if c.Writer.Size() != writerSizeBeforeForward {
+ h.handleFailoverExhausted(c, failoverErr, true)
+ return
+ }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
@@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
-// has been probed to not support the Responses API, the request is
+// is forced or probed to not support the Responses API, the request is
// forwarded directly to /v1/chat/completions — not through the default
// CC→Responses conversion path.
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 6b07b7ba..e7ba699d 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -130,7 +130,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
sessionHashBody := body
if service.IsOpenAIResponsesCompactPathForTest(c) {
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
@@ -189,7 +189,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
@@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
@@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
+ markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
+ writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
+ if c.Writer.Size() != writerSizeBeforeForward {
+ h.handleFailoverExhausted(c, failoverErr, true)
+ return
+ }
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
@@ -604,7 +611,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- setOpsRequestContext(c, reqModel, reqStream, body)
+ setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
@@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
)
if len(failedAccountIDs) == 0 {
if err != nil {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
+ markOpsRoutingCapacityLimited(c)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
reqLog *zap.Logger,
) (func(), bool) {
if selection == nil || selection.Account == nil {
+ markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
}
if selection.WaitPlan == nil {
+ markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@@ -1163,7 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.Bool("has_previous_response_id", previousResponseID != ""),
zap.String("previous_response_id_kind", previousResponseIDKind),
)
- setOpsRequestContext(c, reqModel, true, firstMessage)
+ setOpsRequestContext(c, reqModel, true)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
@@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
+ if service.IsOpenAISilentRefusalErrorBody(responseBody) {
+ service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
+ h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
+ return
+ }
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go
index 08a6b6e8..1a81a59e 100644
--- a/backend/internal/handler/openai_images.go
+++ b/backend/internal/handler/openai_images.go
@@ -63,9 +63,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
}
if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
- setOpsRequestContext(c, "", false, nil)
+ setOpsRequestContext(c, "", false)
} else {
- setOpsRequestContext(c, "", false, body)
+ setOpsRequestContext(c, "", false)
}
parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
@@ -98,9 +98,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
}
if parsed.Multipart {
- setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
+ setOpsRequestContext(c, parsed.Model, parsed.Stream)
} else {
- setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
+ setOpsRequestContext(c, parsed.Model, parsed.Stream)
}
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
@@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
+ markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}
@@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
+ markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index c676637e..166f67d7 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"log"
"runtime"
"runtime/debug"
@@ -22,10 +23,10 @@ import (
)
const (
- opsModelKey = "ops_model"
- opsStreamKey = "ops_stream"
- opsRequestBodyKey = "ops_request_body"
- opsAccountIDKey = "ops_account_id"
+ opsModelKey = "ops_model"
+ opsStreamKey = "ops_stream"
+ opsAccountIDKey = "ops_account_id"
+ opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
@@ -45,6 +46,8 @@ const (
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
+ opsCodeInvalidAPIKey = "INVALID_API_KEY"
+ opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
)
const (
@@ -332,16 +335,13 @@ func opsErrorLogConfig() (workerCount int, queueSize int) {
return workerCount, queueSize
}
-func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) {
+func setOpsRequestContext(c *gin.Context, model string, stream bool) {
if c == nil {
return
}
model = strings.TrimSpace(model)
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
- if len(requestBody) > 0 {
- c.Set(opsRequestBodyKey, requestBody)
- }
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
@@ -360,22 +360,6 @@ func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int
c.Set(opsRequestTypeKey, requestType)
}
-func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
- if c == nil || entry == nil {
- return
- }
- v, ok := c.Get(opsRequestBodyKey)
- if !ok {
- return
- }
- raw, ok := v.([]byte)
- if !ok || len(raw) == 0 {
- return
- }
- entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
- opsErrorLogSanitized.Add(1)
-}
-
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 {
return
@@ -393,6 +377,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string)
}
}
+func markOpsRoutingCapacityLimited(c *gin.Context) {
+ if c == nil {
+ return
+ }
+ c.Set(opsRoutingCapacityLimitedKey, true)
+}
+
+func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) {
+ if !isOpsNoAvailableAccountError(err) {
+ return
+ }
+ markOpsRoutingCapacityLimited(c)
+}
+
+func isOpsRoutingCapacityLimited(c *gin.Context) bool {
+ if c == nil {
+ return false
+ }
+ v, ok := c.Get(opsRoutingCapacityLimitedKey)
+ if !ok {
+ return false
+ }
+ marked, _ := v.(bool)
+ return marked
+}
+
+func isOpsNoAvailableAccountError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) {
+ return true
+ }
+ return isOpsNoAvailableAccountMessage(err.Error())
+}
+
type opsCaptureWriter struct {
gin.ResponseWriter
limit int
@@ -671,7 +691,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorPhase: "upstream",
ErrorType: "upstream_error",
- // Severity/retryability should reflect the upstream failure, not the final client status (200).
+ // Severity should reflect the upstream failure, not the final client status (200).
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
StatusCode: status,
IsBusinessLimited: false,
@@ -688,9 +708,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UpstreamErrorDetail: upstreamErrorDetail,
UpstreamErrors: events,
- IsRetryable: classifyOpsIsRetryable("upstream_error", effectiveUpstreamStatus),
- RetryCount: 0,
- CreatedAt: time.Now(),
+ CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
@@ -714,10 +732,6 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
- // Store request headers/body only when an upstream error occurred to keep overhead minimal.
- entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
- attachOpsRequestBodyToEntry(c, entry)
-
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
@@ -775,11 +789,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
- phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
- isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
-
- errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
- errorSource := classifyOpsErrorSource(phase, parsed.Message)
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status)
entry := &service.OpsInsertErrorLogInput{
RequestID: requestID,
@@ -834,9 +844,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
- IsRetryable: classifyOpsIsRetryable(normalizedType, status),
- RetryCount: 0,
- CreatedAt: time.Now(),
+ CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
@@ -914,20 +922,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
- // Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
- // Do NOT store Authorization/Cookie/etc.
- entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
- attachOpsRequestBodyToEntry(c, entry)
-
enqueueOpsErrorLog(ops, entry)
}
}
-var opsRetryRequestHeaderAllowlist = []string{
- "anthropic-beta",
- "anthropic-version",
-}
-
// isCountTokensRequest checks if the request is a count_tokens request
func isCountTokensRequest(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
@@ -936,32 +934,6 @@ func isCountTokensRequest(c *gin.Context) bool {
return strings.Contains(c.Request.URL.Path, "/count_tokens")
}
-func extractOpsRetryRequestHeaders(c *gin.Context) *string {
- if c == nil || c.Request == nil {
- return nil
- }
-
- headers := make(map[string]string, 4)
- for _, key := range opsRetryRequestHeaderAllowlist {
- v := strings.TrimSpace(c.GetHeader(key))
- if v == "" {
- continue
- }
- // Keep headers small even if a client sends something unexpected.
- headers[key] = truncateString(v, 512)
- }
- if len(headers) == 0 {
- return nil
- }
-
- raw, err := json.Marshal(headers)
- if err != nil {
- return nil
- }
- s := string(raw)
- return &s
-}
-
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
@@ -1116,6 +1088,9 @@ func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
+ if isOpsClientAuthError(code, msg) {
+ return "auth"
+ }
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
@@ -1136,7 +1111,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
- if strings.Contains(msg, opsErrNoAvailableAccounts) {
+ if isOpsNoAvailableAccountMessage(msg) {
return "routing"
}
return "internal"
@@ -1162,25 +1137,31 @@ func classifyOpsSeverity(errType string, status int) string {
return "P3"
}
-func classifyOpsIsRetryable(errType string, statusCode int) bool {
- switch errType {
- case "authentication_error", "invalid_request_error":
- return false
- case "timeout_error":
- return true
- case "rate_limit_error":
- // May be transient (upstream or queue); retry can help.
- return true
- case "billing_error", "subscription_error":
- return false
- case "upstream_error", "overloaded_error":
- return statusCode >= 500 || statusCode == 429 || statusCode == 529
- default:
- return statusCode >= 500
+func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) {
+ phase = classifyOpsPhase(errType, message, code)
+ routingCapacityLimited := isOpsRoutingCapacityLimited(c)
+ clientBusinessLimited := service.HasOpsClientBusinessLimited(c)
+ upstreamError := hasOpsUpstreamErrorContext(c)
+ if upstreamError && !routingCapacityLimited {
+ phase = "upstream"
}
+ if clientBusinessLimited && !upstreamError && !routingCapacityLimited {
+ phase = "auth"
+ }
+ if routingCapacityLimited {
+ phase = "routing"
+ }
+ localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
+ isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
+ errorOwner = classifyOpsErrorOwner(phase, message)
+ errorSource = classifyOpsErrorSource(phase, message)
+ return phase, isBusinessLimited, errorOwner, errorSource
}
-func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
+func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool {
+ if len(localClientAuthError) > 0 && localClientAuthError[0] {
+ return true
+ }
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
@@ -1197,6 +1178,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
return false
}
+func isOpsClientAuthError(code string, msg string) bool {
+ switch strings.TrimSpace(code) {
+ case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
+ return true
+ }
+ return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
+}
+
+func hasOpsUpstreamErrorContext(c *gin.Context) bool {
+ if c == nil {
+ return false
+ }
+ if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
+ switch code := v.(type) {
+ case int:
+ if code > 0 {
+ return true
+ }
+ case int64:
+ if code > 0 {
+ return true
+ }
+ }
+ }
+ if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
+ if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
+ return true
+ }
+ }
+ return false
+}
+
+func isOpsNoAvailableAccountMessage(message string) bool {
+ msg := strings.ToLower(message)
+ return strings.Contains(msg, opsErrNoAvailableAccounts) ||
+ strings.Contains(msg, "no available account") ||
+ strings.Contains(msg, "no available gemini accounts") ||
+ strings.Contains(msg, "no available openai accounts") ||
+ strings.Contains(msg, "no available compatible accounts")
+}
+
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {
diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go
index 6ae45110..99a9af2f 100644
--- a/backend/internal/handler/ops_error_logger_test.go
+++ b/backend/internal/handler/ops_error_logger_test.go
@@ -44,49 +44,6 @@ func resetOpsErrorLoggerStateForTest(t *testing.T) {
opsErrorLogDrained.Store(false)
}
-func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
- resetOpsErrorLoggerStateForTest(t)
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
-
- raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
- setOpsRequestContext(c, "claude-3", false, raw)
-
- entry := &service.OpsInsertErrorLogInput{}
- attachOpsRequestBodyToEntry(c, entry)
-
- require.NotNil(t, entry.RequestBodyBytes)
- require.Equal(t, len(raw), *entry.RequestBodyBytes)
- require.NotNil(t, entry.RequestBodyJSON)
- require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
- require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
- require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
-}
-
-func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
- resetOpsErrorLoggerStateForTest(t)
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
-
- raw := []byte("not-json")
- setOpsRequestContext(c, "claude-3", false, raw)
-
- entry := &service.OpsInsertErrorLogInput{}
- attachOpsRequestBodyToEntry(c, entry)
-
- require.Nil(t, entry.RequestBodyJSON)
- require.NotNil(t, entry.RequestBodyBytes)
- require.Equal(t, len(raw), *entry.RequestBodyBytes)
- require.False(t, entry.RequestBodyTruncated)
- require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
-}
-
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
@@ -108,39 +65,6 @@ func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
require.Equal(t, int64(1), OpsErrorLogQueueLength())
}
-func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
- resetOpsErrorLoggerStateForTest(t)
- gin.SetMode(gin.TestMode)
-
- entry := &service.OpsInsertErrorLogInput{}
- attachOpsRequestBodyToEntry(nil, entry)
- attachOpsRequestBodyToEntry(&gin.Context{}, nil)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
-
- // 无请求体 key
- attachOpsRequestBodyToEntry(c, entry)
- require.Nil(t, entry.RequestBodyJSON)
- require.Nil(t, entry.RequestBodyBytes)
- require.False(t, entry.RequestBodyTruncated)
-
- // 错误类型
- c.Set(opsRequestBodyKey, "not-bytes")
- attachOpsRequestBodyToEntry(c, entry)
- require.Nil(t, entry.RequestBodyJSON)
- require.Nil(t, entry.RequestBodyBytes)
-
- // 空 bytes
- c.Set(opsRequestBodyKey, []byte{})
- attachOpsRequestBodyToEntry(c, entry)
- require.Nil(t, entry.RequestBodyJSON)
- require.Nil(t, entry.RequestBodyBytes)
-
- require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
-}
-
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
@@ -275,6 +199,218 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}
}
+func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) {
+ const message = "No available accounts"
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ markOpsRoutingCapacityLimited(c)
+
+ errType := normalizeOpsErrorType("api_error", "")
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
+
+ require.Equal(t, "api_error", errType)
+ require.Equal(t, "routing", phase)
+ require.True(t, isBusinessLimited)
+ require.Equal(t, "platform", errorOwner)
+ require.Equal(t, "gateway", errorSource)
+}
+
+func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ markOpsRoutingCapacityLimited(c)
+
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
+ c,
+ "api_error",
+ "Service temporarily unavailable",
+ "",
+ http.StatusServiceUnavailable,
+ )
+
+ require.Equal(t, "routing", phase)
+ require.True(t, isBusinessLimited)
+ require.Equal(t, "platform", errorOwner)
+ require.Equal(t, "gateway", errorSource)
+}
+
+func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
+ tests := []struct {
+ name string
+ errType string
+ message string
+ code string
+ status int
+ }{
+ {
+ name: "standard invalid API key",
+ errType: "api_error",
+ message: "Invalid API key",
+ code: "INVALID_API_KEY",
+ status: http.StatusUnauthorized,
+ },
+ {
+ name: "standard missing API key",
+ errType: "api_error",
+ message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header",
+ code: "API_KEY_REQUIRED",
+ status: http.StatusUnauthorized,
+ },
+ {
+ name: "google invalid API key",
+ errType: "api_error",
+ message: "Invalid API key",
+ code: "401",
+ status: http.StatusUnauthorized,
+ },
+ {
+ name: "google missing API key",
+ errType: "api_error",
+ message: "API key is required",
+ code: "401",
+ status: http.StatusUnauthorized,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ errType := normalizeOpsErrorType(tt.errType, tt.code)
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
+
+ require.Equal(t, "api_error", errType)
+ require.Equal(t, "auth", phase)
+ require.True(t, isBusinessLimited)
+ require.Equal(t, "client", errorOwner)
+ require.Equal(t, "client_request", errorSource)
+ })
+ }
+}
+
+func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
+
+ errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED")
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden)
+
+ require.Equal(t, "api_error", errType)
+ require.Equal(t, "auth", phase)
+ require.True(t, isBusinessLimited)
+ require.Equal(t, "client", errorOwner)
+ require.Equal(t, "client_request", errorSource)
+}
+
+func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR")
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError)
+
+ require.Equal(t, "api_error", errType)
+ require.Equal(t, "internal", phase)
+ require.False(t, isBusinessLimited)
+ require.Equal(t, "platform", errorOwner)
+ require.Equal(t, "gateway", errorSource)
+}
+
+func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) {
+ tests := []string{
+ "No available accounts: no available accounts supporting model: made-up-model",
+ "No available accounts: no available OpenAI accounts supporting model: made-up-model",
+ "No available Gemini accounts: no available Gemini accounts supporting model: made-up-model",
+ "No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)",
+ }
+
+ for _, message := range tests {
+ t.Run(message, func(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ markOpsRoutingCapacityLimited(c)
+
+ errType := normalizeOpsErrorType("api_error", "")
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
+
+ require.Equal(t, "api_error", errType)
+ require.Equal(t, "routing", phase)
+ require.True(t, isBusinessLimited)
+ require.Equal(t, "platform", errorOwner)
+ require.Equal(t, "gateway", errorSource)
+ })
+ }
+}
+
+func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
+ c,
+ "api_error",
+ "No available accounts",
+ "",
+ http.StatusServiceUnavailable,
+ )
+
+ require.Equal(t, "routing", phase)
+ require.False(t, isBusinessLimited)
+ require.Equal(t, "platform", errorOwner)
+ require.Equal(t, "gateway", errorSource)
+}
+
+func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
+
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
+ c,
+ "api_error",
+ "Invalid API key",
+ "401",
+ http.StatusUnauthorized,
+ )
+
+ require.Equal(t, "upstream", phase)
+ require.False(t, isBusinessLimited)
+ require.Equal(t, "provider", errorOwner)
+ require.Equal(t, "upstream_http", errorSource)
+}
+
+func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "")
+
+ phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
+ c,
+ "api_error",
+ "No available accounts",
+ "",
+ http.StatusServiceUnavailable,
+ )
+
+ require.Equal(t, "upstream", phase)
+ require.False(t, isBusinessLimited)
+ require.Equal(t, "provider", errorOwner)
+ require.Equal(t, "upstream_http", errorSource)
+}
+
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
diff --git a/backend/internal/handler/page_handler_test.go b/backend/internal/handler/page_handler_test.go
index 0a9f0d96..a6813cdf 100644
--- a/backend/internal/handler/page_handler_test.go
+++ b/backend/internal/handler/page_handler_test.go
@@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected direct image path to be accepted")
}
- want := filepath.Join(base, "logo.png")
+ want := mustEvalSymlinks(t, filepath.Join(base, "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected nested image path to be accepted")
}
- want = filepath.Join(base, "images", "logo.png")
+ want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
t.Fatalf("expected symlink escape to be rejected, got %q", got)
}
}
+
+func mustEvalSymlinks(t *testing.T, path string) string {
+ t.Helper()
+
+ realPath, err := filepath.EvalSymlinks(path)
+ if err != nil {
+ t.Fatalf("eval symlinks for %q: %v", path, err)
+ }
+ return realPath
+}
diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go
index f293c2f2..1bb81190 100644
--- a/backend/internal/handler/payment_handler.go
+++ b/backend/internal/handler/payment_handler.go
@@ -141,6 +141,7 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
HelpText: cfg.HelpText,
HelpImageURL: cfg.HelpImageURL,
StripePublishableKey: cfg.StripePublishableKey,
+ AlipayForceQRCode: cfg.AlipayForceQRCode,
})
}
@@ -155,6 +156,7 @@ type checkoutInfoResponse struct {
HelpText string `json:"help_text"`
HelpImageURL string `json:"help_image_url"`
StripePublishableKey string `json:"stripe_publishable_key"`
+ AlipayForceQRCode bool `json:"alipay_force_qrcode"`
}
type checkoutPlan struct {
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 6c389e3d..c4ba43e4 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
+ DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go
index 3f6ed8c2..f1dbf4e1 100644
--- a/backend/internal/handler/user_handler.go
+++ b/backend/internal/handler/user_handler.go
@@ -67,6 +67,7 @@ type userProfileResponse struct {
LinuxDoBound bool `json:"linuxdo_bound"`
OIDCBound bool `json:"oidc_bound"`
WeChatBound bool `json:"wechat_bound"`
+ DingTalkBound bool `json:"dingtalk_bound"`
}
type userProfileSourceContext struct {
@@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
+ DingTalkBound: identities.DingTalk.Bound,
}
}
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
return map[string]service.UserIdentitySummary{
- "email": identities.Email,
- "linuxdo": identities.LinuxDo,
- "oidc": identities.OIDC,
- "wechat": identities.WeChat,
+ "email": identities.Email,
+ "linuxdo": identities.LinuxDo,
+ "oidc": identities.OIDC,
+ "wechat": identities.WeChat,
+ "dingtalk": identities.DingTalk,
}
}
@@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
- for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
+ for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} {
if summary.Bound {
out = append(out, summary)
}
diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go
index 1234b568..c4c6e634 100644
--- a/backend/internal/payment/provider/alipay.go
+++ b/backend/internal/payment/provider/alipay.go
@@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
// CreatePayment creates an Alipay payment using the following routing:
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
-// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
-// - Desktop fallback: if precreate is unavailable for the merchant, fall back
-// to alipay.trade.page.pay and expose both pay_url and qr_code so the
-// frontend can render a QR while still allowing direct page open.
+// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to
+// get a scannable QR payload. If precreate is unavailable for the merchant,
+// fall back to alipay.trade.page.pay and expose pay_url only — the frontend
+// opens the Alipay checkout in a new tab.
+// - Desktop, paymentMode == "redirect": skip precreate and go straight to
+// alipay.trade.page.pay so the frontend always opens the Alipay checkout
+// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT.
+//
+// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable
+// payment QR. Never expose it via the QRCode field.
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
}
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
+ // Explicit redirect mode: merchant opted into "always open the Alipay
+ // checkout page in a new tab" via the provider instance's payment_mode.
+ // Skip precreate to avoid a wasted API call.
+ if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") {
+ return a.createPagePayTrade(client, req, notifyURL, returnURL)
+ }
+
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
if precreateErr == nil {
return resp, nil
@@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
+ // Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL
+ // that must be opened in a browser, not a scannable payment QR. Setting it
+ // as QRCode would let the frontend render an unscannable image.
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
- QRCode: payURL.String(),
}, nil
}
diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go
index fdc8eec1..9f8aec53 100644
--- a/backend/internal/payment/provider/alipay_test.go
+++ b/backend/internal/payment/provider/alipay_test.go
@@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
if resp.PayURL == "" {
t.Fatal("expected pay_url for desktop page pay")
}
- if resp.QRCode != resp.PayURL {
- t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
+ // page.pay returns a checkout page URL, not a scannable QR payload —
+ // it must never be exposed via QRCode (the frontend would render an
+ // unscannable image from it).
+ if resp.QRCode != "" {
+ t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode)
+ }
+}
+
+// When the provider instance is configured with paymentMode == "redirect",
+// the desktop flow must skip precreate and go straight to page.pay.
+func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) {
+ origPreCreate := alipayTradePreCreate
+ origPagePay := alipayTradePagePay
+ t.Cleanup(func() {
+ alipayTradePreCreate = origPreCreate
+ alipayTradePagePay = origPagePay
+ })
+
+ preCreateCalls := 0
+ pagePayCalls := 0
+ alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
+ preCreateCalls++
+ return &alipay.TradePreCreateRsp{
+ Error: alipay.Error{Code: alipay.CodeSuccess},
+ QRCode: "https://qr.alipay.example.com/precreate-token",
+ }, nil
+ }
+ alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
+ pagePayCalls++
+ if param.ProductCode != alipayProductCodePagePay {
+ t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay)
+ }
+ return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
+ }
+
+ provider := &Alipay{
+ config: map[string]string{"paymentMode": "redirect"},
+ }
+ resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
+ OrderID: "sub2_103",
+ Amount: "12.00",
+ Subject: "Balance recharge",
+ }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if preCreateCalls != 0 {
+ t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls)
+ }
+ if pagePayCalls != 1 {
+ t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
+ }
+ if resp.PayURL == "" {
+ t.Fatal("expected pay_url for redirect mode")
+ }
+ if resp.QRCode != "" {
+ t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode)
}
}
diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go
index 16aff9f8..e318d1cd 100644
--- a/backend/internal/pkg/antigravity/client.go
+++ b/backend/internal/pkg/antigravity/client.go
@@ -254,6 +254,8 @@ const (
proxyTLSHandshakeTimeout = 5 * time.Second
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body)
clientTimeout = 10 * time.Second
+ // fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
+ fetchAvailableModelsBodyLimit int64 = 8 << 20
)
func NewClient(proxyURL string) (*Client, error) {
@@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
// 支持 URL fallback:sandbox → daily → prod
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
+ if c == nil || c.httpClient == nil {
+ return nil, nil, errors.New("antigravity client is not configured")
+ }
+
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
@@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 固定顺序:prod -> daily
availableURLs := BaseURLs
+ fetchClient := c.fetchAvailableModelsHTTPClient()
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:fetchAvailableModels"
@@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
- resp, err := c.httpClient.Do(req)
+ resp, err := fetchClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
@@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
- respBodyBytes, err := io.ReadAll(resp.Body)
+ respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
+ if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
+ return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
+ }
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
@@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
+func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
+ fetchClient := *c.httpClient
+ fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
+ return &fetchClient
+}
+
+func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
+ if len(via) >= 10 {
+ return errors.New("stopped after 10 redirects")
+ }
+ if req == nil || req.URL == nil {
+ return errors.New("redirect url is nil")
+ }
+ if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
+ return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
+ }
+ return nil
+}
+
+func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
+ host = strings.ToLower(strings.TrimSpace(host))
+ if host == "" {
+ return false
+ }
+ for _, baseURL := range BaseURLs {
+ parsed, err := url.Parse(baseURL)
+ if err != nil {
+ continue
+ }
+ if strings.EqualFold(host, parsed.Hostname()) {
+ return true
+ }
+ }
+ return false
+}
+
// ── Privacy API ──────────────────────────────────────────────────────
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)
diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go
index aa36ef0b..7490654d 100644
--- a/backend/internal/pkg/apicompat/anthropic_responses_test.go
+++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go
@@ -744,6 +744,10 @@ func TestStreamingReasoning(t *testing.T) {
assert.Equal(t, "content_block_start", events[0].Type)
assert.Equal(t, "thinking", events[0].ContentBlock.Type)
+ sse, err := ResponsesAnthropicEventToSSE(events[0])
+ require.NoError(t, err)
+ assert.Contains(t, sse, `"content_block":{"thinking":"","type":"thinking"}`)
+
// reasoning text delta
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
@@ -1520,3 +1524,49 @@ func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
assert.JSONEq(t, `"object"`, string(params["type"]))
assert.JSONEq(t, `{}`, string(params["properties"]))
}
+
+// ---------------------------------------------------------------------------
+// isReasoningModel / temperature-stripping tests
+// ---------------------------------------------------------------------------
+
+func TestAnthropicToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
+ temp := 0.7
+ req := &AnthropicRequest{
+ Model: "gpt-5.2",
+ MaxTokens: 1024,
+ Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
+ Temperature: &temp,
+ TopP: &temp,
+ }
+
+ resp, err := AnthropicToResponses(req)
+ require.NoError(t, err)
+ assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
+ assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
+
+ // Verify the fields are absent from the serialised JSON.
+ b, err := json.Marshal(resp)
+ require.NoError(t, err)
+ assert.NotContains(t, string(b), `"temperature"`)
+ assert.NotContains(t, string(b), `"top_p"`)
+}
+
+func TestAnthropicToResponses_TemperatureStrippedForAllGpt5Variants(t *testing.T) {
+ temp := 1.0
+ models := []string{"gpt-5.2", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.5"}
+ for _, model := range models {
+ t.Run(model, func(t *testing.T) {
+ req := &AnthropicRequest{
+ Model: model,
+ MaxTokens: 1024,
+ Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
+ Temperature: &temp,
+ TopP: &temp,
+ }
+ resp, err := AnthropicToResponses(req)
+ require.NoError(t, err)
+ assert.Nil(t, resp.Temperature, "model %s: temperature must be stripped", model)
+ assert.Nil(t, resp.TopP, "model %s: top_p must be stripped", model)
+ })
+ }
+}
diff --git a/backend/internal/pkg/apicompat/anthropic_to_responses.go b/backend/internal/pkg/apicompat/anthropic_to_responses.go
index 5f04004d..e2011bee 100644
--- a/backend/internal/pkg/apicompat/anthropic_to_responses.go
+++ b/backend/internal/pkg/apicompat/anthropic_to_responses.go
@@ -22,12 +22,19 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
}
out := &ResponsesRequest{
- Model: req.Model,
- Input: inputJSON,
- Temperature: req.Temperature,
- TopP: req.TopP,
- Stream: req.Stream,
- Include: []string{"reasoning.encrypted_content"},
+ Model: req.Model,
+ Input: inputJSON,
+ Stream: req.Stream,
+ Include: []string{"reasoning.encrypted_content"},
+ }
+
+ // Reasoning models (gpt-5.x) served via the Responses API do not accept
+ // sampling parameters. Sending temperature or top_p causes a 400
+ // "Unsupported parameter" error, so we only forward them for non-reasoning
+ // models.
+ if !isReasoningModel(req.Model) {
+ out.Temperature = req.Temperature
+ out.TopP = req.TopP
}
storeFalse := false
@@ -437,6 +444,14 @@ func boolPtr(v bool) *bool {
return &v
}
+// isReasoningModel reports whether model is a reasoning model that does not
+// support sampling parameters (temperature, top_p) via the Responses API.
+// All gpt-5.x models are reasoning-only; the Responses API returns
+// "Unsupported parameter: temperature" if these fields are present.
+func isReasoningModel(model string) bool {
+ return strings.HasPrefix(model, "gpt-5")
+}
+
// normalizeToolParameters ensures the tool parameter schema is valid for
// OpenAI's Responses API, which requires "properties" on object schemas.
//
diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
index bf5c23d5..ad26f273 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go
@@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi
assert.Equal(t, "Describe this", parts[0].Text)
}
+func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) {
+ // Regression for #2515: the upstream Responses API rejects an input item
+ // whose content field is JSON null. Any chat-completions message that
+ // yields no usable content parts must serialize content as a string.
+ cases := []struct {
+ name string
+ content json.RawMessage
+ }{
+ {"null content", json.RawMessage(`null`)},
+ {"empty array content", json.RawMessage(`[]`)},
+ {"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)},
+ {"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)},
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := &ChatCompletionsRequest{
+ Model: "gpt-5.5",
+ Messages: []ChatMessage{
+ {Role: "user", Content: tc.content},
+ },
+ }
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+ assert.NotContains(t, string(resp.Input), `"content":null`,
+ "converted input must not contain a null content field")
+
+ var items []ResponsesInputItem
+ require.NoError(t, json.Unmarshal(resp.Input, &items))
+ require.Len(t, items, 1)
+ assert.Equal(t, `""`, string(items[0].Content),
+ "content must be an empty string, not null")
+ })
+ }
+}
+
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
@@ -296,6 +331,48 @@ func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
assert.Equal(t, "flex", resp.ServiceTier)
}
+// ---------------------------------------------------------------------------
+// temperature / top_p stripping for reasoning models
+// ---------------------------------------------------------------------------
+
+func TestChatCompletionsToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
+ temp := 0.7
+ req := &ChatCompletionsRequest{
+ Model: "gpt-5.2",
+ Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
+ Temperature: &temp,
+ TopP: &temp,
+ }
+
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+ assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
+ assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
+
+ // Must not appear in the serialised request body sent to the upstream.
+ b, err := json.Marshal(resp)
+ require.NoError(t, err)
+ assert.NotContains(t, string(b), `"temperature"`)
+ assert.NotContains(t, string(b), `"top_p"`)
+}
+
+func TestChatCompletionsToResponses_TemperaturePreservedForNonReasoningModel(t *testing.T) {
+ temp := 0.7
+ req := &ChatCompletionsRequest{
+ Model: "gpt-4o",
+ Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
+ Temperature: &temp,
+ TopP: &temp,
+ }
+
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+ require.NotNil(t, resp.Temperature, "non-reasoning model: temperature must be preserved")
+ assert.InDelta(t, 0.7, *resp.Temperature, 1e-9)
+ require.NotNil(t, resp.TopP, "non-reasoning model: top_p must be preserved")
+ assert.InDelta(t, 0.7, *resp.TopP, 1e-9)
+}
+
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
@@ -379,6 +456,34 @@ func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T)
assert.Contains(t, parts[0].Text, "final answer")
}
+func TestChatCompletionsToResponses_AssistantReasoningContentPreserved(t *testing.T) {
+ req := &ChatCompletionsRequest{
+ Model: "gpt-4o",
+ Messages: []ChatMessage{
+ {Role: "user", Content: json.RawMessage(`"Hi"`)},
+ {
+ Role: "assistant",
+ ReasoningContent: "internal plan",
+ Content: json.RawMessage(`"final answer"`),
+ },
+ },
+ }
+
+ resp, err := ChatCompletionsToResponses(req)
+ require.NoError(t, err)
+
+ var items []ResponsesInputItem
+ require.NoError(t, json.Unmarshal(resp.Input, &items))
+ require.Len(t, items, 2)
+
+ var parts []ResponsesContentPart
+ require.NoError(t, json.Unmarshal(items[1].Content, &parts))
+ require.Len(t, parts, 1)
+ assert.Equal(t, "output_text", parts[0].Type)
+ assert.Contains(t, parts[0].Text, "internal plan ")
+ assert.Contains(t, parts[0].Text, "final answer")
+}
+
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions tests
// ---------------------------------------------------------------------------
diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
index 64ef5781..463bdd0d 100644
--- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
+++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go
@@ -30,13 +30,18 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
Model: req.Model,
Instructions: req.Instructions,
Input: inputJSON,
- Temperature: req.Temperature,
- TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
+ // Reasoning models (gpt-5.x) do not accept sampling parameters.
+ // See isReasoningModel in anthropic_to_responses.go.
+ if !isReasoningModel(req.Model) {
+ out.Temperature = req.Temperature
+ out.TopP = req.TopP
+ }
+
storeFalse := false
out.Store = &storeFalse
@@ -150,6 +155,11 @@ func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
+ content := ""
+
+ if m.ReasoningContent != "" {
+ content = "" + m.ReasoningContent + " "
+ }
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
@@ -158,15 +168,22 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
return nil, err
}
if s != "" {
- parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
- partsJSON, err := json.Marshal(parts)
- if err != nil {
- return nil, err
+ if content != "" {
+ content += "\n"
}
- items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
+ content += s
}
}
+ if content != "" {
+ parts := []ResponsesContentPart{{Type: "output_text", Text: content}}
+ partsJSON, err := json.Marshal(parts)
+ if err != nil {
+ return nil, err
+ }
+ items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
+ }
+
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
@@ -325,7 +342,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error
if content.Text != nil {
return json.Marshal(*content.Text)
}
- return json.Marshal(convertChatContentPartsToResponses(content.Parts))
+ parts := convertChatContentPartsToResponses(content.Parts)
+ if len(parts) == 0 {
+ // A nil slice marshals to JSON null, which the upstream Responses API
+ // rejects ("expected an array of objects or string, but got null").
+ // Fall back to an empty string when no usable parts remain.
+ return json.Marshal("")
+ }
+ return json.Marshal(parts)
}
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go
index f9cd5a1c..7c46ccaf 100644
--- a/backend/internal/pkg/apicompat/types.go
+++ b/backend/internal/pkg/apicompat/types.go
@@ -75,6 +75,28 @@ type AnthropicContentBlock struct {
IsError bool `json:"is_error,omitempty"`
}
+func (b AnthropicContentBlock) MarshalJSON() ([]byte, error) {
+ type anthropicContentBlock AnthropicContentBlock
+ base := struct {
+ anthropicContentBlock
+ }{anthropicContentBlock: anthropicContentBlock(b)}
+
+ switch b.Type {
+ case "text":
+ return json.Marshal(struct {
+ Text string `json:"text"`
+ anthropicContentBlock
+ }{Text: b.Text, anthropicContentBlock: anthropicContentBlock(b)})
+ case "thinking":
+ return json.Marshal(struct {
+ Thinking string `json:"thinking"`
+ anthropicContentBlock
+ }{Thinking: b.Thinking, anthropicContentBlock: anthropicContentBlock(b)})
+ default:
+ return json.Marshal(base)
+ }
+}
+
// AnthropicImageSource describes the source data for an image content block.
type AnthropicImageSource struct {
Type string `json:"type"` // "base64"
@@ -306,6 +328,37 @@ type ResponsesUsage struct {
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
}
+func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
+ type responsesUsageAlias ResponsesUsage
+ var aux struct {
+ responsesUsageAlias
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"`
+ CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"`
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ *u = ResponsesUsage(aux.responsesUsageAlias)
+ if u.InputTokens == 0 && aux.PromptTokens != 0 {
+ u.InputTokens = aux.PromptTokens
+ }
+ if u.OutputTokens == 0 && aux.CompletionTokens != 0 {
+ u.OutputTokens = aux.CompletionTokens
+ }
+ if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil {
+ u.InputTokensDetails = aux.PromptTokensDetails
+ }
+ if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil {
+ u.OutputTokensDetails = aux.CompletionTokensDetails
+ }
+ if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) {
+ u.TotalTokens = u.InputTokens + u.OutputTokens
+ }
+ return nil
+}
+
// ResponsesInputTokensDetails breaks down input token usage.
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`
diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go
index fac79d18..d5d4ed64 100644
--- a/backend/internal/pkg/gemini/models.go
+++ b/backend/internal/pkg/gemini/models.go
@@ -22,6 +22,7 @@ func DefaultModels() []Model {
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
+ {Name: "models/gemini-3.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go
index 195fb06f..bbd9a6c4 100644
--- a/backend/internal/pkg/geminicli/models.go
+++ b/backend/internal/pkg/geminicli/models.go
@@ -15,6 +15,7 @@ var DefaultModels = []Model{
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
+ {ID: "gemini-3.5-flash", Type: "model", DisplayName: "Gemini 3.5 Flash", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
diff --git a/backend/internal/pkg/openai_compat/upstream_capability.go b/backend/internal/pkg/openai_compat/upstream_capability.go
index ff05afe5..154a01fb 100644
--- a/backend/internal/pkg/openai_compat/upstream_capability.go
+++ b/backend/internal/pkg/openai_compat/upstream_capability.go
@@ -17,7 +17,7 @@
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
package openai_compat
-// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
+// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。
//
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
type AccountResponsesSupport int
@@ -35,11 +35,43 @@ const (
ResponsesSupportNo
)
-// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
+// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。
+type ResponsesSupportMode string
+
+const (
+ // ResponsesSupportModeAuto 表示跟随自动探测结果。
+ ResponsesSupportModeAuto ResponsesSupportMode = "auto"
+
+ // ResponsesSupportModeForceResponses 强制使用 /v1/responses。
+ ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses"
+
+ // ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。
+ ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions"
+)
+
+// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。
+// 值类型为 string:auto=跟随探测,force_responses=强制 Responses,
+// force_chat_completions=强制 Chat Completions。
+const ExtraKeyResponsesMode = "openai_responses_mode"
+
+// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。
// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。
const ExtraKeyResponsesSupported = "openai_responses_supported"
-// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
+// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。
+// 缺失或非法值按 auto 处理,以保持存量行为。
+func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode {
+ switch ResponsesSupportMode(mode) {
+ case ResponsesSupportModeForceResponses:
+ return ResponsesSupportModeForceResponses
+ case ResponsesSupportModeForceChatCompletions:
+ return ResponsesSupportModeForceChatCompletions
+ default:
+ return ResponsesSupportModeAuto
+ }
+}
+
+// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。
//
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。
@@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
if extra == nil {
return ResponsesSupportUnknown
}
+ if mode, ok := extra[ExtraKeyResponsesMode].(string); ok {
+ switch NormalizeResponsesSupportMode(mode) {
+ case ResponsesSupportModeForceResponses:
+ return ResponsesSupportYes
+ case ResponsesSupportModeForceChatCompletions:
+ return ResponsesSupportNo
+ }
+ }
v, ok := extra[ExtraKeyResponsesSupported]
if !ok {
return ResponsesSupportUnknown
diff --git a/backend/internal/pkg/openai_compat/upstream_capability_test.go b/backend/internal/pkg/openai_compat/upstream_capability_test.go
index d650daa4..008579a7 100644
--- a/backend/internal/pkg/openai_compat/upstream_capability_test.go
+++ b/backend/internal/pkg/openai_compat/upstream_capability_test.go
@@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) {
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
+ {"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes},
+ {"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo},
+ {"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
+ {"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
+ {"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes},
+ {"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo},
}
for _, tc := range tests {
@@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) {
// 已探测:标记决定
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
+
+ // 手动覆盖:覆盖自动探测结果
+ {"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true},
+ {"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false},
}
for _, tc := range tests {
@@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) {
})
}
}
+
+func TestNormalizeResponsesSupportMode(t *testing.T) {
+ tests := []struct {
+ name string
+ mode string
+ want ResponsesSupportMode
+ }{
+ {"empty", "", ResponsesSupportModeAuto},
+ {"auto", "auto", ResponsesSupportModeAuto},
+ {"force responses", "force_responses", ResponsesSupportModeForceResponses},
+ {"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions},
+ {"invalid", "enabled", ResponsesSupportModeAuto},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ got := NormalizeResponsesSupportMode(tc.mode)
+ if got != tc.want {
+ t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want)
+ }
+ })
+ }
+}
diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go
index fe5f98d6..39283d22 100644
--- a/backend/internal/pkg/usagestats/usage_log_types.go
+++ b/backend/internal/pkg/usagestats/usage_log_types.go
@@ -230,6 +230,20 @@ type UserDashboardStats struct {
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
+
+ // 按"有效平台"维度拆分(与 ops 路径口径一致:group.platform 优先,否则 account.platform)
+ ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"`
+}
+
+// PlatformDashboardStats 单个平台的用量明细。
+type PlatformDashboardStats struct {
+ Platform string `json:"platform"`
+ TotalRequests int64 `json:"total_requests"`
+ TotalTokens int64 `json:"total_tokens"`
+ TotalActualCost float64 `json:"total_actual_cost"`
+ TodayRequests int64 `json:"today_requests"`
+ TodayTokens int64 `json:"today_tokens"`
+ TodayActualCost float64 `json:"today_actual_cost"`
}
// UsageLogFilters represents filters for usage log queries
@@ -265,13 +279,22 @@ type UsageStats struct {
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
-// BatchUserUsageStats represents usage stats for a single user
-type BatchUserUsageStats struct {
- UserID int64 `json:"user_id"`
+// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。
+// Platform 取值与 ops 路径口径一致:优先 groups.platform,否则 accounts.platform。
+type PlatformUsage struct {
+ Platform string `json:"platform"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
+// BatchUserUsageStats represents usage stats for a single user
+type BatchUserUsageStats struct {
+ UserID int64 `json:"user_id"`
+ TodayActualCost float64 `json:"today_actual_cost"`
+ TotalActualCost float64 `json:"total_actual_cost"`
+ ByPlatform []PlatformUsage `json:"by_platform,omitempty"`
+}
+
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`
diff --git a/backend/internal/repository/account_repo_compact_extra_test.go b/backend/internal/repository/account_repo_compact_extra_test.go
index 604f392e..e2ce6602 100644
--- a/backend/internal/repository/account_repo_compact_extra_test.go
+++ b/backend/internal/repository/account_repo_compact_extra_test.go
@@ -12,3 +12,14 @@ func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRel
t.Fatalf("expected compact capability updates to enqueue scheduler outbox")
}
}
+
+func TestShouldEnqueueSchedulerOutboxForExtraUpdates_OpenAIResponsesCapabilityKeysAreRelevant(t *testing.T) {
+ updates := map[string]any{
+ "openai_responses_mode": "force_chat_completions",
+ "openai_responses_supported": false,
+ }
+
+ if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) {
+ t.Fatalf("expected responses capability updates to enqueue scheduler outbox")
+ }
+}
diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go
index afe1fb25..f19c24f1 100644
--- a/backend/internal/repository/announcement_repo.go
+++ b/backend/internal/repository/announcement_repo.go
@@ -204,7 +204,8 @@ func (r *announcementRepository) ListActive(ctx context.Context, now time.Time)
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
).
- Order(dbent.Desc(announcement.FieldID))
+ Order(dbent.Desc(announcement.FieldID)).
+ Limit(200)
items, err := q.All(ctx)
if err != nil {
diff --git a/backend/internal/repository/channel_monitor_repo.go b/backend/internal/repository/channel_monitor_repo.go
index 800ee43b..6666a130 100644
--- a/backend/internal/repository/channel_monitor_repo.go
+++ b/backend/internal/repository/channel_monitor_repo.go
@@ -37,6 +37,7 @@ func (r *channelMonitorRepository) Create(ctx context.Context, m *service.Channe
builder := client.ChannelMonitor.Create().
SetName(m.Name).
SetProvider(channelmonitor.Provider(m.Provider)).
+ SetAPIMode(defaultAPIModeRepo(m.APIMode)).
SetEndpoint(m.Endpoint).
SetAPIKeyEncrypted(m.APIKey). // 调用方传入的已是密文
SetPrimaryModel(m.PrimaryModel).
@@ -79,6 +80,7 @@ func (r *channelMonitorRepository) Update(ctx context.Context, m *service.Channe
updater := client.ChannelMonitor.UpdateOneID(m.ID).
SetName(m.Name).
SetProvider(channelmonitor.Provider(m.Provider)).
+ SetAPIMode(defaultAPIModeRepo(m.APIMode)).
SetEndpoint(m.Endpoint).
SetAPIKeyEncrypted(m.APIKey).
SetPrimaryModel(m.PrimaryModel).
@@ -708,6 +710,7 @@ func entToServiceMonitor(row *dbent.ChannelMonitor) *service.ChannelMonitor {
ID: row.ID,
Name: row.Name,
Provider: string(row.Provider),
+ APIMode: defaultAPIModeRepo(row.APIMode),
Endpoint: row.Endpoint,
APIKey: row.APIKeyEncrypted, // 仍为密文,service 层负责解密
PrimaryModel: row.PrimaryModel,
@@ -747,6 +750,13 @@ func defaultBodyModeRepo(mode string) string {
return mode
}
+func defaultAPIModeRepo(apiMode string) string {
+ if apiMode == "" {
+ return "chat_completions"
+ }
+ return apiMode
+}
+
func emptySliceIfNil(in []string) []string {
if in == nil {
return []string{}
diff --git a/backend/internal/repository/channel_monitor_template_repo.go b/backend/internal/repository/channel_monitor_template_repo.go
index 845d186b..3a972360 100644
--- a/backend/internal/repository/channel_monitor_template_repo.go
+++ b/backend/internal/repository/channel_monitor_template_repo.go
@@ -30,6 +30,7 @@ func (r *channelMonitorRequestTemplateRepository) Create(ctx context.Context, t
builder := client.ChannelMonitorRequestTemplate.Create().
SetName(t.Name).
SetProvider(channelmonitorrequesttemplate.Provider(t.Provider)).
+ SetAPIMode(defaultAPIModeRepo(t.APIMode)).
SetDescription(t.Description).
SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
@@ -61,6 +62,7 @@ func (r *channelMonitorRequestTemplateRepository) Update(ctx context.Context, t
client := clientFromContext(ctx, r.client)
updater := client.ChannelMonitorRequestTemplate.UpdateOneID(t.ID).
SetName(t.Name).
+ SetAPIMode(defaultAPIModeRepo(t.APIMode)).
SetDescription(t.Description).
SetExtraHeaders(emptyHeadersIfNilRepo(t.ExtraHeaders)).
SetBodyOverrideMode(defaultBodyModeRepo(t.BodyOverrideMode))
@@ -90,8 +92,11 @@ func (r *channelMonitorRequestTemplateRepository) List(ctx context.Context, para
if params.Provider != "" {
q = q.Where(channelmonitorrequesttemplate.ProviderEQ(channelmonitorrequesttemplate.Provider(params.Provider)))
}
+ if params.APIMode != "" {
+ q = q.Where(channelmonitorrequesttemplate.APIModeEQ(defaultAPIModeRepo(params.APIMode)))
+ }
rows, err := q.
- Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
+ Order(dbent.Asc(channelmonitorrequesttemplate.FieldProvider), dbent.Asc(channelmonitorrequesttemplate.FieldAPIMode), dbent.Asc(channelmonitorrequesttemplate.FieldName)).
All(ctx)
if err != nil {
return nil, fmt.Errorf("list monitor templates: %w", err)
@@ -122,7 +127,10 @@ func (r *channelMonitorRequestTemplateRepository) ApplyToMonitors(ctx context.Co
Where(
channelmonitor.TemplateIDEQ(id),
channelmonitor.IDIn(monitorIDs...),
+ channelmonitor.ProviderEQ(channelmonitor.Provider(tpl.Provider)),
+ channelmonitor.APIModeEQ(defaultAPIModeRepo(tpl.APIMode)),
).
+ SetAPIMode(defaultAPIModeRepo(tpl.APIMode)).
SetExtraHeaders(emptyHeadersIfNilRepo(tpl.ExtraHeaders)).
SetBodyOverrideMode(defaultBodyModeRepo(tpl.BodyOverrideMode))
if tpl.BodyOverride != nil {
@@ -165,6 +173,7 @@ func (r *channelMonitorRequestTemplateRepository) ListAssociatedMonitors(ctx con
ID: row.ID,
Name: row.Name,
Provider: string(row.Provider),
+ APIMode: defaultAPIModeRepo(row.APIMode),
Enabled: row.Enabled,
})
}
@@ -185,6 +194,7 @@ func entToServiceTemplate(row *dbent.ChannelMonitorRequestTemplate) *service.Cha
ID: row.ID,
Name: row.Name,
Provider: string(row.Provider),
+ APIMode: defaultAPIModeRepo(row.APIMode),
Description: row.Description,
ExtraHeaders: headers,
BodyOverrideMode: row.BodyOverrideMode,
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 112575f4..9b6377bc 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -283,47 +283,90 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
}
func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) {
- groups, err := q.
+ // 第一步:只查 ID + sort_order(轻量,不做分页 — 需要全量排序 account_count)。
+ rows, err := q.Clone().
+ Select(group.FieldID, group.FieldSortOrder).
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
- groupIDs := make([]int64, 0, len(groups))
- outGroups := make([]service.Group, 0, len(groups))
- for i := range groups {
- g := groupEntityToService(groups[i])
- outGroups = append(outGroups, *g)
- groupIDs = append(groupIDs, g.ID)
+ type sortEntry struct {
+ id int64
+ sortOrder int
+ accountCount int64
+ }
+ entries := make([]sortEntry, 0, len(rows))
+ groupIDs := make([]int64, len(rows))
+ for i, r := range rows {
+ groupIDs[i] = r.ID
+ entries = append(entries, sortEntry{id: r.ID, sortOrder: r.SortOrder})
}
+ // 第二步:批量加载 account counts(一次 SQL)。
counts, err := r.loadAccountCounts(ctx, groupIDs)
if err != nil {
return nil, nil, err
}
- for i := range outGroups {
- c := counts[outGroups[i].ID]
- outGroups[i].AccountCount = c.Total
- outGroups[i].ActiveAccountCount = c.Active
- outGroups[i].RateLimitedAccountCount = c.RateLimited
+ for i := range entries {
+ c := counts[entries[i].id]
+ if c.Total > 0 {
+ entries[i].accountCount = c.Total
+ }
}
+ // 第三步:Go 侧排序(数据量 = Group 总数,通常 < 200,安全)。
sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc)
- sort.SliceStable(outGroups, func(i, j int) bool {
- if outGroups[i].AccountCount == outGroups[j].AccountCount {
- if outGroups[i].SortOrder == outGroups[j].SortOrder {
- return outGroups[i].ID < outGroups[j].ID
- }
- return outGroups[i].SortOrder < outGroups[j].SortOrder
+ tieCmp := func(a, b sortEntry) bool {
+ if a.sortOrder == b.sortOrder {
+ return a.id < b.id
+ }
+ return a.sortOrder < b.sortOrder
+ }
+ sort.SliceStable(entries, func(i, j int) bool {
+ if entries[i].accountCount == entries[j].accountCount {
+ return tieCmp(entries[i], entries[j])
}
if sortOrder == pagination.SortOrderAsc {
- return outGroups[i].AccountCount < outGroups[j].AccountCount
+ return entries[i].accountCount < entries[j].accountCount
}
- return outGroups[i].AccountCount > outGroups[j].AccountCount
+ return entries[i].accountCount > entries[j].accountCount
})
- return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil
+ // 第四步:分页,只加载当前页需要的完整 Group。
+ page := paginateSlice(entries, params)
+ if len(page) == 0 {
+ return nil, paginationResultFromTotal(int64(total), params), nil
+ }
+
+ pageIDs := make([]int64, len(page))
+ pageIdx := make(map[int64]int, len(page))
+ for i, e := range page {
+ pageIDs[i] = e.id
+ pageIdx[e.id] = i
+ }
+
+ groups, err := r.client.Group.Query().
+ Where(group.IDIn(pageIDs...)).
+ All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ outGroups := make([]service.Group, len(page))
+ for i := range groups {
+ g := groupEntityToService(groups[i])
+ c := counts[g.ID]
+ g.AccountCount = c.Total
+ g.ActiveAccountCount = c.Active
+ g.RateLimitedAccountCount = c.RateLimited
+ if idx, ok := pageIdx[g.ID]; ok {
+ outGroups[idx] = *g
+ }
+ }
+
+ return outGroups, paginationResultFromTotal(int64(total), params), nil
}
func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) {
diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go
index eeee5c23..7ef82f0c 100644
--- a/backend/internal/repository/migrations_schema_integration_test.go
+++ b/backend/internal/repository/migrations_schema_integration_test.go
@@ -44,6 +44,33 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
+ requireColumn(t, tx, "usage_logs", "image_input_size", "character varying", 32, true)
+ requireColumn(t, tx, "usage_logs", "image_output_size", "character varying", 32, true)
+ requireColumn(t, tx, "usage_logs", "image_size_source", "character varying", 16, true)
+ requireColumn(t, tx, "usage_logs", "image_size_breakdown", "jsonb", 0, true)
+ requireConstraintDefinitionContains(
+ t,
+ tx,
+ "usage_logs",
+ "usage_logs_image_size_source_check",
+ "image_size_source",
+ "'output'",
+ "'input'",
+ "'default'",
+ "'legacy'",
+ )
+ requireConstraintDefinitionContains(
+ t,
+ tx,
+ "usage_logs",
+ "usage_logs_image_billing_size_check",
+ "image_count",
+ "image_size IS NOT NULL",
+ "'1K'",
+ "'2K'",
+ "'4K'",
+ "'mixed'",
+ )
// usage_billing_dedup: billing idempotency narrow table
var usageBillingDedupRegclass sql.NullString
diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go
index 5154b269..4371b8a2 100644
--- a/backend/internal/repository/ops_repo.go
+++ b/backend/internal/repository/ops_repo.go
@@ -54,15 +54,9 @@ INSERT INTO ops_error_logs (
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
- request_body,
- request_body_truncated,
- request_body_bytes,
- request_headers,
- is_retryable,
- retry_count,
created_at
) VALUES (
- $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43
+ $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37
)`
func NewOpsRepository(db *sql.DB) service.OpsRepository {
@@ -170,12 +164,6 @@ func opsInsertErrorLogArgs(input *service.OpsInsertErrorLogInput) []any {
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs),
- opsNullString(input.RequestBodyJSON),
- input.RequestBodyTruncated,
- opsNullInt(input.RequestBodyBytes),
- opsNullString(input.RequestHeadersJSON),
- input.IsRetryable,
- input.RetryCount,
input.CreatedAt,
}
}
@@ -222,13 +210,10 @@ SELECT
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
- COALESCE(e.is_retryable, false),
- COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
COALESCE(u2.email, ''),
- e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
@@ -277,7 +262,6 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
var resolvedByName string
- var resolvedRetryID sql.NullInt64
var requestType sql.NullInt64
if err := rows.Scan(
&item.ID,
@@ -290,13 +274,10 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
&statusCode,
&item.Platform,
&item.Model,
- &item.IsRetryable,
- &item.RetryCount,
&item.Resolved,
&resolvedAt,
&resolvedBy,
&resolvedByName,
- &resolvedRetryID,
&item.ClientRequestID,
&item.RequestID,
&item.Message,
@@ -327,10 +308,6 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
item.ResolvedByUserID = &v
}
item.ResolvedByUserName = resolvedByName
- if resolvedRetryID.Valid {
- v := resolvedRetryID.Int64
- item.ResolvedRetryID = &v
- }
item.StatusCode = int(statusCode.Int64)
if clientIP.Valid {
s := clientIP.String
@@ -393,12 +370,9 @@ SELECT
COALESCE(e.upstream_status_code, e.status_code, 0),
COALESCE(e.platform, ''),
COALESCE(e.model, ''),
- COALESCE(e.is_retryable, false),
- COALESCE(e.retry_count, 0),
COALESCE(e.resolved, false),
e.resolved_at,
e.resolved_by_user_id,
- e.resolved_retry_id,
COALESCE(e.client_request_id, ''),
COALESCE(e.request_id, ''),
COALESCE(e.error_message, ''),
@@ -428,11 +402,7 @@ SELECT
e.routing_latency_ms,
e.upstream_latency_ms,
e.response_latency_ms,
- e.time_to_first_token_ms,
- COALESCE(e.request_body::text, ''),
- e.request_body_truncated,
- e.request_body_bytes,
- COALESCE(e.request_headers::text, '')
+ e.time_to_first_token_ms
FROM ops_error_logs e
LEFT JOIN users u ON e.user_id = u.id
LEFT JOIN accounts a ON e.account_id = a.id
@@ -445,7 +415,6 @@ LIMIT 1`
var upstreamStatusCode sql.NullInt64
var resolvedAt sql.NullTime
var resolvedBy sql.NullInt64
- var resolvedRetryID sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
@@ -456,7 +425,6 @@ LIMIT 1`
var upstreamLatency sql.NullInt64
var responseLatency sql.NullInt64
var ttft sql.NullInt64
- var requestBodyBytes sql.NullInt64
var requestType sql.NullInt64
err := r.db.QueryRowContext(ctx, q, id).Scan(
@@ -470,12 +438,9 @@ LIMIT 1`
&statusCode,
&out.Platform,
&out.Model,
- &out.IsRetryable,
- &out.RetryCount,
&out.Resolved,
&resolvedAt,
&resolvedBy,
- &resolvedRetryID,
&out.ClientRequestID,
&out.RequestID,
&out.Message,
@@ -506,10 +471,6 @@ LIMIT 1`
&upstreamLatency,
&responseLatency,
&ttft,
- &out.RequestBody,
- &out.RequestBodyTruncated,
- &requestBodyBytes,
- &out.RequestHeaders,
)
if err != nil {
return nil, err
@@ -524,10 +485,6 @@ LIMIT 1`
v := resolvedBy.Int64
out.ResolvedByUserID = &v
}
- if resolvedRetryID.Valid {
- v := resolvedRetryID.Int64
- out.ResolvedRetryID = &v
- }
if clientIP.Valid {
s := clientIP.String
out.ClientIP = &s
@@ -572,25 +529,11 @@ LIMIT 1`
v := ttft.Int64
out.TimeToFirstTokenMs = &v
}
- if requestBodyBytes.Valid {
- v := int(requestBodyBytes.Int64)
- out.RequestBodyBytes = &v
- }
if requestType.Valid {
v := int16(requestType.Int64)
out.RequestType = &v
}
- // Normalize request_body to empty string when stored as JSON null.
- out.RequestBody = strings.TrimSpace(out.RequestBody)
- if out.RequestBody == "null" {
- out.RequestBody = ""
- }
- // Normalize request_headers to empty string when stored as JSON null.
- out.RequestHeaders = strings.TrimSpace(out.RequestHeaders)
- if out.RequestHeaders == "null" {
- out.RequestHeaders = ""
- }
// Normalize upstream_errors to empty string when stored as JSON null.
out.UpstreamErrors = strings.TrimSpace(out.UpstreamErrors)
if out.UpstreamErrors == "null" {
@@ -600,398 +543,7 @@ LIMIT 1`
return &out, nil
}
-func (r *opsRepository) InsertRetryAttempt(ctx context.Context, input *service.OpsInsertRetryAttemptInput) (int64, error) {
- if r == nil || r.db == nil {
- return 0, fmt.Errorf("nil ops repository")
- }
- if input == nil {
- return 0, fmt.Errorf("nil input")
- }
- if input.SourceErrorID <= 0 {
- return 0, fmt.Errorf("invalid source_error_id")
- }
- if strings.TrimSpace(input.Mode) == "" {
- return 0, fmt.Errorf("invalid mode")
- }
-
- q := `
-INSERT INTO ops_retry_attempts (
- requested_by_user_id,
- source_error_id,
- mode,
- pinned_account_id,
- status,
- started_at
-) VALUES (
- $1,$2,$3,$4,$5,$6
-) RETURNING id`
-
- var id int64
- err := r.db.QueryRowContext(
- ctx,
- q,
- opsNullInt64(&input.RequestedByUserID),
- input.SourceErrorID,
- strings.TrimSpace(input.Mode),
- opsNullInt64(input.PinnedAccountID),
- strings.TrimSpace(input.Status),
- input.StartedAt,
- ).Scan(&id)
- if err != nil {
- return 0, err
- }
- return id, nil
-}
-
-func (r *opsRepository) UpdateRetryAttempt(ctx context.Context, input *service.OpsUpdateRetryAttemptInput) error {
- if r == nil || r.db == nil {
- return fmt.Errorf("nil ops repository")
- }
- if input == nil {
- return fmt.Errorf("nil input")
- }
- if input.ID <= 0 {
- return fmt.Errorf("invalid id")
- }
-
- q := `
-UPDATE ops_retry_attempts
-SET
- status = $2,
- finished_at = $3,
- duration_ms = $4,
- success = $5,
- http_status_code = $6,
- upstream_request_id = $7,
- used_account_id = $8,
- response_preview = $9,
- response_truncated = $10,
- result_request_id = $11,
- result_error_id = $12,
- error_message = $13
-WHERE id = $1`
-
- _, err := r.db.ExecContext(
- ctx,
- q,
- input.ID,
- strings.TrimSpace(input.Status),
- nullTime(input.FinishedAt),
- input.DurationMs,
- nullBool(input.Success),
- nullInt(input.HTTPStatusCode),
- opsNullString(input.UpstreamRequestID),
- nullInt64(input.UsedAccountID),
- opsNullString(input.ResponsePreview),
- nullBool(input.ResponseTruncated),
- opsNullString(input.ResultRequestID),
- nullInt64(input.ResultErrorID),
- opsNullString(input.ErrorMessage),
- )
- return err
-}
-
-func (r *opsRepository) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*service.OpsRetryAttempt, error) {
- if r == nil || r.db == nil {
- return nil, fmt.Errorf("nil ops repository")
- }
- if sourceErrorID <= 0 {
- return nil, fmt.Errorf("invalid source_error_id")
- }
-
- q := `
-SELECT
- id,
- created_at,
- COALESCE(requested_by_user_id, 0),
- source_error_id,
- COALESCE(mode, ''),
- pinned_account_id,
- COALESCE(status, ''),
- started_at,
- finished_at,
- duration_ms,
- success,
- http_status_code,
- upstream_request_id,
- used_account_id,
- response_preview,
- response_truncated,
- result_request_id,
- result_error_id,
- error_message
-FROM ops_retry_attempts
-WHERE source_error_id = $1
-ORDER BY created_at DESC
-LIMIT 1`
-
- var out service.OpsRetryAttempt
- var pinnedAccountID sql.NullInt64
- var requestedBy sql.NullInt64
- var startedAt sql.NullTime
- var finishedAt sql.NullTime
- var durationMs sql.NullInt64
- var success sql.NullBool
- var httpStatusCode sql.NullInt64
- var upstreamRequestID sql.NullString
- var usedAccountID sql.NullInt64
- var responsePreview sql.NullString
- var responseTruncated sql.NullBool
- var resultRequestID sql.NullString
- var resultErrorID sql.NullInt64
- var errorMessage sql.NullString
-
- err := r.db.QueryRowContext(ctx, q, sourceErrorID).Scan(
- &out.ID,
- &out.CreatedAt,
- &requestedBy,
- &out.SourceErrorID,
- &out.Mode,
- &pinnedAccountID,
- &out.Status,
- &startedAt,
- &finishedAt,
- &durationMs,
- &success,
- &httpStatusCode,
- &upstreamRequestID,
- &usedAccountID,
- &responsePreview,
- &responseTruncated,
- &resultRequestID,
- &resultErrorID,
- &errorMessage,
- )
- if err != nil {
- return nil, err
- }
- out.RequestedByUserID = requestedBy.Int64
- if pinnedAccountID.Valid {
- v := pinnedAccountID.Int64
- out.PinnedAccountID = &v
- }
- if startedAt.Valid {
- t := startedAt.Time
- out.StartedAt = &t
- }
- if finishedAt.Valid {
- t := finishedAt.Time
- out.FinishedAt = &t
- }
- if durationMs.Valid {
- v := durationMs.Int64
- out.DurationMs = &v
- }
- if success.Valid {
- v := success.Bool
- out.Success = &v
- }
- if httpStatusCode.Valid {
- v := int(httpStatusCode.Int64)
- out.HTTPStatusCode = &v
- }
- if upstreamRequestID.Valid {
- s := upstreamRequestID.String
- out.UpstreamRequestID = &s
- }
- if usedAccountID.Valid {
- v := usedAccountID.Int64
- out.UsedAccountID = &v
- }
- if responsePreview.Valid {
- s := responsePreview.String
- out.ResponsePreview = &s
- }
- if responseTruncated.Valid {
- v := responseTruncated.Bool
- out.ResponseTruncated = &v
- }
- if resultRequestID.Valid {
- s := resultRequestID.String
- out.ResultRequestID = &s
- }
- if resultErrorID.Valid {
- v := resultErrorID.Int64
- out.ResultErrorID = &v
- }
- if errorMessage.Valid {
- s := errorMessage.String
- out.ErrorMessage = &s
- }
-
- return &out, nil
-}
-
-func nullTime(t time.Time) sql.NullTime {
- if t.IsZero() {
- return sql.NullTime{}
- }
- return sql.NullTime{Time: t, Valid: true}
-}
-
-func nullBool(v *bool) sql.NullBool {
- if v == nil {
- return sql.NullBool{}
- }
- return sql.NullBool{Bool: *v, Valid: true}
-}
-
-func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
- if r == nil || r.db == nil {
- return nil, fmt.Errorf("nil ops repository")
- }
- if sourceErrorID <= 0 {
- return nil, fmt.Errorf("invalid source_error_id")
- }
- if limit <= 0 {
- limit = 50
- }
- if limit > 200 {
- limit = 200
- }
-
- q := `
-SELECT
- r.id,
- r.created_at,
- COALESCE(r.requested_by_user_id, 0),
- r.source_error_id,
- COALESCE(r.mode, ''),
- r.pinned_account_id,
- COALESCE(pa.name, ''),
- COALESCE(r.status, ''),
- r.started_at,
- r.finished_at,
- r.duration_ms,
- r.success,
- r.http_status_code,
- r.upstream_request_id,
- r.used_account_id,
- COALESCE(ua.name, ''),
- r.response_preview,
- r.response_truncated,
- r.result_request_id,
- r.result_error_id,
- r.error_message
-FROM ops_retry_attempts r
-LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
-LEFT JOIN accounts ua ON r.used_account_id = ua.id
-WHERE r.source_error_id = $1
-ORDER BY r.created_at DESC
-LIMIT $2`
-
- rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
- if err != nil {
- return nil, err
- }
- defer func() { _ = rows.Close() }()
-
- out := make([]*service.OpsRetryAttempt, 0, 16)
- for rows.Next() {
- var item service.OpsRetryAttempt
- var pinnedAccountID sql.NullInt64
- var pinnedAccountName string
- var requestedBy sql.NullInt64
- var startedAt sql.NullTime
- var finishedAt sql.NullTime
- var durationMs sql.NullInt64
- var success sql.NullBool
- var httpStatusCode sql.NullInt64
- var upstreamRequestID sql.NullString
- var usedAccountID sql.NullInt64
- var usedAccountName string
- var responsePreview sql.NullString
- var responseTruncated sql.NullBool
- var resultRequestID sql.NullString
- var resultErrorID sql.NullInt64
- var errorMessage sql.NullString
-
- if err := rows.Scan(
- &item.ID,
- &item.CreatedAt,
- &requestedBy,
- &item.SourceErrorID,
- &item.Mode,
- &pinnedAccountID,
- &pinnedAccountName,
- &item.Status,
- &startedAt,
- &finishedAt,
- &durationMs,
- &success,
- &httpStatusCode,
- &upstreamRequestID,
- &usedAccountID,
- &usedAccountName,
- &responsePreview,
- &responseTruncated,
- &resultRequestID,
- &resultErrorID,
- &errorMessage,
- ); err != nil {
- return nil, err
- }
-
- item.RequestedByUserID = requestedBy.Int64
- if pinnedAccountID.Valid {
- v := pinnedAccountID.Int64
- item.PinnedAccountID = &v
- }
- item.PinnedAccountName = pinnedAccountName
- if startedAt.Valid {
- t := startedAt.Time
- item.StartedAt = &t
- }
- if finishedAt.Valid {
- t := finishedAt.Time
- item.FinishedAt = &t
- }
- if durationMs.Valid {
- v := durationMs.Int64
- item.DurationMs = &v
- }
- if success.Valid {
- v := success.Bool
- item.Success = &v
- }
- if httpStatusCode.Valid {
- v := int(httpStatusCode.Int64)
- item.HTTPStatusCode = &v
- }
- if upstreamRequestID.Valid {
- item.UpstreamRequestID = &upstreamRequestID.String
- }
- if usedAccountID.Valid {
- v := usedAccountID.Int64
- item.UsedAccountID = &v
- }
- item.UsedAccountName = usedAccountName
- if responsePreview.Valid {
- item.ResponsePreview = &responsePreview.String
- }
- if responseTruncated.Valid {
- v := responseTruncated.Bool
- item.ResponseTruncated = &v
- }
- if resultRequestID.Valid {
- item.ResultRequestID = &resultRequestID.String
- }
- if resultErrorID.Valid {
- v := resultErrorID.Int64
- item.ResultErrorID = &v
- }
- if errorMessage.Valid {
- item.ErrorMessage = &errorMessage.String
- }
- out = append(out, &item)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return out, nil
-}
-
-func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
+func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
@@ -1004,8 +556,7 @@ UPDATE ops_error_logs
SET
resolved = $2,
resolved_at = $3,
- resolved_by_user_id = $4,
- resolved_retry_id = $5
+ resolved_by_user_id = $4
WHERE id = $1`
at := sql.NullTime{}
@@ -1023,7 +574,6 @@ WHERE id = $1`
resolved,
at,
nullInt64(resolvedByUserID),
- nullInt64(resolvedRetryID),
)
return err
}
diff --git a/backend/internal/repository/ops_repo_replay_cleanup_test.go b/backend/internal/repository/ops_repo_replay_cleanup_test.go
new file mode 100644
index 00000000..a6a15e9a
--- /dev/null
+++ b/backend/internal/repository/ops_repo_replay_cleanup_test.go
@@ -0,0 +1,44 @@
+package repository
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func TestOpsErrorLogInsertDoesNotPersistRequestReplayFields(t *testing.T) {
+ disallowedColumns := []string{
+ "request_body",
+ "request_headers",
+ "request_body_truncated",
+ "request_body_bytes",
+ "is_retryable",
+ "retry_count",
+ "resolved_retry_id",
+ }
+
+ insertSQL := strings.ToLower(insertOpsErrorLogSQL)
+ for _, column := range disallowedColumns {
+ if strings.Contains(insertSQL, column) {
+ t.Fatalf("ops error log insert still references dropped replay column %q", column)
+ }
+ }
+
+ inputType := reflect.TypeOf(service.OpsInsertErrorLogInput{})
+ disallowedFields := []string{
+ "RequestBodyJSON",
+ "RequestBodyTruncated",
+ "RequestBodyBytes",
+ "RequestHeadersJSON",
+ "IsRetryable",
+ "RetryCount",
+ "ResolvedRetryID",
+ }
+ for _, field := range disallowedFields {
+ if _, ok := inputType.FieldByName(field); ok {
+ t.Fatalf("OpsInsertErrorLogInput still carries replay field %q", field)
+ }
+ }
+}
diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go
index 07975970..47c38d3e 100644
--- a/backend/internal/repository/redeem_code_repo.go
+++ b/backend/internal/repository/redeem_code_repo.go
@@ -30,6 +30,7 @@ func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemC
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays).
+ SetNillableExpiresAt(code.ExpiresAt).
SetNillableUsedBy(code.UsedBy).
SetNillableUsedAt(code.UsedAt).
SetNillableGroupID(code.GroupID).
@@ -56,6 +57,7 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.
SetStatus(c.Status).
SetNotes(c.Notes).
SetValidityDays(c.ValidityDays).
+ SetNillableExpiresAt(c.ExpiresAt).
SetNillableUsedBy(c.UsedBy).
SetNillableUsedAt(c.UsedAt).
SetNillableGroupID(c.GroupID)
@@ -107,7 +109,28 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin
q = q.Where(redeemcode.TypeEQ(codeType))
}
if status != "" {
- q = q.Where(redeemcode.StatusEQ(status))
+ now := time.Now()
+ switch status {
+ case service.StatusExpired:
+ q = q.Where(redeemcode.Or(
+ redeemcode.StatusEQ(service.StatusExpired),
+ redeemcode.And(
+ redeemcode.StatusEQ(service.StatusUnused),
+ redeemcode.ExpiresAtNotNil(),
+ redeemcode.ExpiresAtLTE(now),
+ ),
+ ))
+ case service.StatusUnused:
+ q = q.Where(
+ redeemcode.StatusEQ(service.StatusUnused),
+ redeemcode.Or(
+ redeemcode.ExpiresAtIsNil(),
+ redeemcode.ExpiresAtGT(now),
+ ),
+ )
+ default:
+ q = q.Where(redeemcode.StatusEQ(status))
+ }
}
if search != "" {
q = q.Where(
@@ -158,6 +181,8 @@ func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Sele
field = redeemcode.FieldUsedAt
case "created_at":
field = redeemcode.FieldCreatedAt
+ case "expires_at":
+ field = redeemcode.FieldExpiresAt
case "code":
field = redeemcode.FieldCode
default:
@@ -194,6 +219,11 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC
} else {
up.ClearGroupID()
}
+ if code.ExpiresAt != nil {
+ up.SetExpiresAt(*code.ExpiresAt)
+ } else {
+ up.ClearExpiresAt()
+ }
updated, err := up.Save(ctx)
if err != nil {
@@ -307,6 +337,7 @@ func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
UsedAt: m.UsedAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
+ ExpiresAt: m.ExpiresAt,
GroupID: m.GroupID,
ValidityDays: m.ValidityDays,
}
diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go
index 39674b52..24e5910e 100644
--- a/backend/internal/repository/redeem_code_repo_integration_test.go
+++ b/backend/internal/repository/redeem_code_repo_integration_test.go
@@ -51,11 +51,13 @@ func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
// --- Create / CreateBatch / GetByID / GetByCode ---
func (s *RedeemCodeRepoSuite) TestCreate() {
+ expiresAt := time.Now().UTC().Add(2 * time.Hour)
code := &service.RedeemCode{
- Code: "TEST-CREATE",
- Type: service.RedeemTypeBalance,
- Value: 100,
- Status: service.StatusUnused,
+ Code: "TEST-CREATE",
+ Type: service.RedeemTypeBalance,
+ Value: 100,
+ Status: service.StatusUnused,
+ ExpiresAt: &expiresAt,
}
err := s.repo.Create(s.ctx, code)
@@ -65,6 +67,8 @@ func (s *RedeemCodeRepoSuite) TestCreate() {
got, err := s.repo.GetByID(s.ctx, code.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("TEST-CREATE", got.Code)
+ s.Require().NotNil(got.ExpiresAt)
+ s.Require().WithinDuration(expiresAt, *got.ExpiresAt, time.Second)
}
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
@@ -166,6 +170,23 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
s.Require().Equal(service.StatusUsed, codes[0].Status)
}
+func (s *RedeemCodeRepoSuite) TestListWithFilters_StatusExpiredByExpiresAt() {
+ past := time.Now().UTC().Add(-time.Hour)
+ future := time.Now().UTC().Add(time.Hour)
+ s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-EXPIRED-BY-TIME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &past}))
+ s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED-FUTURE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &future}))
+
+ expired, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusExpired, "")
+ s.Require().NoError(err)
+ s.Require().Len(expired, 1)
+ s.Require().Equal("STAT-EXPIRED-BY-TIME", expired[0].Code)
+
+ unused, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUnused, "")
+ s.Require().NoError(err)
+ s.Require().Len(unused, 1)
+ s.Require().Equal("STAT-UNUSED-FUTURE", unused[0].Code)
+}
+
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go
index f1a42ef7..36956fb2 100644
--- a/backend/internal/repository/scheduler_cache.go
+++ b/backend/internal/repository/scheduler_cache.go
@@ -546,6 +546,8 @@ func filterSchedulerExtra(extra map[string]any) map[string]any {
"responses_websockets_v2_enabled",
"openai_ws_enabled",
"openai_ws_force_http",
+ "openai_responses_mode",
+ "openai_responses_supported",
// model_rate_limits 必须进入调度快照:SetModelRateLimit 写入的模型级冷却
// 时间戳(accounts.extra.model_rate_limits..rate_limit_reset_at)
// 是 isAccountSchedulableForModelSelection/IsSchedulableForModelWithContext
diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go
index 32dda0a8..5dc5242a 100644
--- a/backend/internal/repository/scheduler_cache_unit_test.go
+++ b/backend/internal/repository/scheduler_cache_unit_test.go
@@ -18,6 +18,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
"openai_oauth_responses_websockets_v2_enabled": true,
"openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough,
"openai_ws_force_http": true,
+ "openai_responses_mode": "force_chat_completions",
+ "openai_responses_supported": false,
"mixed_scheduling": true,
"unused_large_field": "drop-me",
},
@@ -28,6 +30,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) {
require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"])
require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"])
require.Equal(t, true, got.Extra["openai_ws_force_http"])
+ require.Equal(t, "force_chat_completions", got.Extra["openai_responses_mode"])
+ require.Equal(t, false, got.Extra["openai_responses_supported"])
require.Equal(t, true, got.Extra["mixed_scheduling"])
require.Nil(t, got.Extra["unused_large_field"])
}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index f2fb87da..f11910a0 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -28,7 +28,7 @@ import (
gocache "github.com/patrickmn/go-cache"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, image_input_size, image_output_size, image_size_source, image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at"
// usageLogInsertArgTypes must stay in the same order as:
// 1. prepareUsageLogInsert().args
@@ -73,6 +73,10 @@ var usageLogInsertArgTypes = [...]string{
"text", // ip_address
"integer", // image_count
"text", // image_size
+ "text", // image_input_size
+ "text", // image_output_size
+ "text", // image_size_source
+ "jsonb", // image_size_breakdown
"text", // service_tier
"text", // reasoning_effort
"text", // inbound_endpoint
@@ -92,6 +96,22 @@ const rawUsageLogModelColumn = "model"
// Historical rows may contain upstream/billing model values, while newer rows store requested_model.
// Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead.
+// usageLogSuccessFilterUL 用于把"失败请求 usage log"(tokens=0、cost=0、不计费的占位记录)
+// 从统计性聚合中排除,避免污染 Dashboard / 用量拆分等指标。
+//
+// schema 中没有 success bool 列;新增列要做迁移,风险大;这里用 actual_cost > 0 作为代理:
+// 任何成功落账的请求都会产生 actual_cost(包括 token 计费、纯图片 token 计费、按次/按图计费),
+// 反之 failed-request usage log 的 actual_cost 为 0。
+// 早期版本用 4 项 token 和 > 0 判定会把"按次/按图计费"与"image_output_tokens 独立计费"的纯图片
+// 请求误判为失败,导致这部分请求从用量统计里消失,故改用 actual_cost。
+// 配合 `FROM usage_logs ul` JOIN 查询使用。
+const usageLogSuccessFilterUL = "ul.actual_cost > 0"
+
+// usageLogEffectivePlatformExpr 用于按"有效平台"维度聚合 usage_logs:
+// 优先取请求实际走的分组 platform,若分组未设置 platform 再 fallback 到 account.platform。
+// 配套要求查询里 LEFT JOIN groups g ON g.id = ul.group_id 与 LEFT JOIN accounts a ON a.id = ul.account_id。
+const usageLogEffectivePlatformExpr = "COALESCE(NULLIF(g.platform,''), a.platform)"
+
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
"hour": "YYYY-MM-DD HH24:00",
@@ -120,6 +140,24 @@ func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model
return conditions, args
}
+func appendUsageLogBillingModeWhereCondition(conditions []string, args []any, billingMode string) ([]string, []any) {
+ mode := strings.TrimSpace(billingMode)
+ if mode == "" {
+ return conditions, args
+ }
+ placeholder := fmt.Sprintf("$%d", len(args)+1)
+ switch service.BillingMode(mode) {
+ case service.BillingModeImage:
+ conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR COALESCE(image_count, 0) > 0)", placeholder))
+ case service.BillingModeToken:
+ conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))", placeholder))
+ default:
+ conditions = append(conditions, fmt.Sprintf("billing_mode = %s", placeholder))
+ }
+ args = append(args, mode)
+ return conditions, args
+}
+
// appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward
// compatibility with historical rows. Requested/upstream analytics must use
// resolveModelDimensionExpression instead.
@@ -352,6 +390,10 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -369,7 +411,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor,
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -790,6 +832,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -803,7 +849,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
created_at
) AS (VALUES `)
- args := make([]any, 0, len(keys)*46)
+ args := make([]any, 0, len(keys)*50)
argPos := 1
for idx, key := range keys {
if idx > 0 {
@@ -867,6 +913,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -915,6 +965,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1003,6 +1057,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1016,7 +1074,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
created_at
) AS (VALUES `)
- args := make([]any, 0, len(preparedList)*46)
+ args := make([]any, 0, len(preparedList)*50)
argPos := 1
for idx, prepared := range preparedList {
if idx > 0 {
@@ -1077,6 +1135,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1125,6 +1187,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1181,6 +1247,10 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
ip_address,
image_count,
image_size,
+ image_input_size,
+ image_output_size,
+ image_size_source,
+ image_size_breakdown,
service_tier,
reasoning_effort,
inbound_endpoint,
@@ -1198,7 +1268,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared
$10, $11, $12, $13,
$14, $15, $16, $17,
$18, $19, $20, $21, $22, $23,
- $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46
+ $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
`, prepared.args...)
@@ -1225,6 +1295,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
+ imageInputSize := nullString(log.ImageInputSize)
+ imageOutputSize := nullString(log.ImageOutputSize)
+ imageSizeSource := nullString(log.ImageSizeSource)
+ imageSizeBreakdown := nullStringIntMapJSON(log.ImageSizeBreakdown)
serviceTier := nullString(log.ServiceTier)
reasoningEffort := nullString(log.ReasoningEffort)
inboundEndpoint := nullString(log.InboundEndpoint)
@@ -1285,6 +1359,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared {
ipAddress,
log.ImageCount,
imageSize,
+ imageInputSize,
+ imageOutputSize,
+ imageSizeSource,
+ imageSizeBreakdown,
serviceTier,
reasoningEffort,
inboundEndpoint,
@@ -2352,6 +2430,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi
// UserDashboardStats 用户仪表盘统计
type UserDashboardStats = usagestats.UserDashboardStats
+// PlatformDashboardStats 单平台用量明细
+type PlatformDashboardStats = usagestats.PlatformDashboardStats
+
// GetUserDashboardStats 获取用户专属的仪表盘统计
func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) {
stats := &UserDashboardStats{}
@@ -2447,6 +2528,57 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
stats.Rpm = rpm
stats.Tpm = tpm
+ // 按"有效平台"维度拆分(group.platform 优先,否则 account.platform)。
+ // 与 ops 路径口径一致;HAVING 过滤掉无法确定平台的行(避免出现空字符串平台)。
+ // 与上面 totalStatsQuery/todayStatsQuery 的总值可能略微差异,原因有二:
+ // 1) 无平台归属的极少数行(group/account 都没 platform)会被 HAVING 排除;
+ // 2) usageLogSuccessFilterUL 会把 actual_cost = 0 的失败 placeholder 行排除,
+ // 而 totalStatsQuery/todayStatsQuery 没有这层过滤、会把这些行的 request 计数算进去。
+ platformQuery := `
+ SELECT
+ ` + usageLogEffectivePlatformExpr + ` as platform,
+ COUNT(*) as total_requests,
+ COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
+ COALESCE(SUM(ul.actual_cost), 0) as total_actual_cost,
+ COUNT(*) FILTER (WHERE ul.created_at >= $2) as today_requests,
+ COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) FILTER (WHERE ul.created_at >= $2), 0) as today_tokens,
+ COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2), 0) as today_actual_cost
+ FROM usage_logs ul
+ LEFT JOIN groups g ON g.id = ul.group_id
+ LEFT JOIN accounts a ON a.id = ul.account_id
+ WHERE ul.user_id = $1
+ AND ` + usageLogSuccessFilterUL + `
+ GROUP BY ` + usageLogEffectivePlatformExpr + `
+ HAVING ` + usageLogEffectivePlatformExpr + ` IS NOT NULL AND ` + usageLogEffectivePlatformExpr + ` <> ''
+ ORDER BY total_actual_cost DESC
+ `
+ rows, err := r.sql.QueryContext(ctx, platformQuery, userID, today)
+ if err != nil {
+ return nil, err
+ }
+ for rows.Next() {
+ var p PlatformDashboardStats
+ if err := rows.Scan(
+ &p.Platform,
+ &p.TotalRequests,
+ &p.TotalTokens,
+ &p.TotalActualCost,
+ &p.TodayRequests,
+ &p.TodayTokens,
+ &p.TodayActualCost,
+ ); err != nil {
+ _ = rows.Close()
+ return nil, err
+ }
+ stats.ByPlatform = append(stats.ByPlatform, p)
+ }
+ if err := rows.Close(); err != nil {
+ return nil, err
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+
return stats, nil
}
@@ -2662,10 +2794,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
- if filters.BillingMode != "" {
- conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
- args = append(args, filters.BillingMode)
- }
+ conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -2710,6 +2839,9 @@ type UsageStats = usagestats.UsageStats
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats = usagestats.BatchUserUsageStats
+// PlatformUsage represents per-platform usage breakdown
+type PlatformUsage = usagestats.PlatformUsage
+
func normalizePositiveInt64IDs(ids []int64) []int64 {
if len(ids) == 0 {
return nil
@@ -2750,15 +2882,21 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
result[id] = &BatchUserUsageStats{UserID: id}
}
+ // GROUP BY (user_id, effective_platform) 一次查询同时得到总值与按平台拆分。
+ // 应用层把同一 user_id 的多行累加为总值,并把非空 platform 行收集到 ByPlatform。
query := `
SELECT
- user_id,
- COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost,
- COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost
- FROM usage_logs
- WHERE user_id = ANY($1)
- AND created_at >= LEAST($2, $4)
- GROUP BY user_id
+ ul.user_id,
+ ` + usageLogEffectivePlatformExpr + ` as platform,
+ COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2 AND ul.created_at < $3), 0) as total_cost,
+ COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $4), 0) as today_cost
+ FROM usage_logs ul
+ LEFT JOIN groups g ON g.id = ul.group_id
+ LEFT JOIN accounts a ON a.id = ul.account_id
+ WHERE ul.user_id = ANY($1)
+ AND ul.created_at >= LEAST($2, $4)
+ AND ` + usageLogSuccessFilterUL + `
+ GROUP BY ul.user_id, ` + usageLogEffectivePlatformExpr + `
`
today := timezone.Today()
rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today)
@@ -2767,15 +2905,25 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
}
for rows.Next() {
var userID int64
+ var platform sql.NullString
var total float64
var todayTotal float64
- if err := rows.Scan(&userID, &total, &todayTotal); err != nil {
+ if err := rows.Scan(&userID, &platform, &total, &todayTotal); err != nil {
_ = rows.Close()
return nil, err
}
- if stats, ok := result[userID]; ok {
- stats.TotalActualCost = total
- stats.TodayActualCost = todayTotal
+ stats, ok := result[userID]
+ if !ok {
+ continue
+ }
+ stats.TotalActualCost += total
+ stats.TodayActualCost += todayTotal
+ if platform.Valid && platform.String != "" {
+ stats.ByPlatform = append(stats.ByPlatform, PlatformUsage{
+ Platform: platform.String,
+ TotalActualCost: total,
+ TodayActualCost: todayTotal,
+ })
}
}
if err := rows.Close(); err != nil {
@@ -3363,10 +3511,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
}
- if filters.BillingMode != "" {
- conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1))
- args = append(args, filters.BillingMode)
- }
+ conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode)
if filters.StartTime != nil {
conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1))
args = append(args, *filters.StartTime)
@@ -4084,6 +4229,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
+ imageInputSize sql.NullString
+ imageOutputSize sql.NullString
+ imageSizeSource sql.NullString
+ imageSizeBreakdown sql.NullString
serviceTier sql.NullString
reasoningEffort sql.NullString
inboundEndpoint sql.NullString
@@ -4134,6 +4283,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&ipAddress,
&imageCount,
&imageSize,
+ &imageInputSize,
+ &imageOutputSize,
+ &imageSizeSource,
+ &imageSizeBreakdown,
&serviceTier,
&reasoningEffort,
&inboundEndpoint,
@@ -4212,6 +4365,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
+ if imageInputSize.Valid {
+ log.ImageInputSize = &imageInputSize.String
+ }
+ if imageOutputSize.Valid {
+ log.ImageOutputSize = &imageOutputSize.String
+ }
+ if imageSizeSource.Valid {
+ log.ImageSizeSource = &imageSizeSource.String
+ }
+ log.ImageSizeBreakdown = stringIntMapFromNullJSON(imageSizeBreakdown)
if serviceTier.Valid {
log.ServiceTier = &serviceTier.String
}
@@ -4378,6 +4541,31 @@ func nullString(v *string) sql.NullString {
return sql.NullString{String: *v, Valid: true}
}
+func nullStringIntMapJSON(v map[string]int) any {
+ if len(v) == 0 {
+ return nil
+ }
+ payload, err := json.Marshal(v)
+ if err != nil {
+ return nil
+ }
+ return string(payload)
+}
+
+func stringIntMapFromNullJSON(v sql.NullString) map[string]int {
+ if !v.Valid || strings.TrimSpace(v.String) == "" {
+ return nil
+ }
+ var out map[string]int
+ if err := json.Unmarshal([]byte(v.String), &out); err != nil {
+ return nil
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
func coalesceTrimmedString(v sql.NullString, fallback string) string {
if v.Valid && strings.TrimSpace(v.String) != "" {
return v.String
diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go
index a5ff4bc1..597c9597 100644
--- a/backend/internal/repository/usage_log_repo_request_type_test.go
+++ b/backend/internal/repository/usage_log_repo_request_type_test.go
@@ -76,6 +76,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
sqlmock.AnyArg(), // ip_address
log.ImageCount,
sqlmock.AnyArg(), // image_size
+ sqlmock.AnyArg(), // image_input_size
+ sqlmock.AnyArg(), // image_output_size
+ sqlmock.AnyArg(), // image_size_source
+ sqlmock.AnyArg(), // image_size_breakdown
sqlmock.AnyArg(), // service_tier
sqlmock.AnyArg(), // reasoning_effort
sqlmock.AnyArg(), // inbound_endpoint
@@ -155,6 +159,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) {
sqlmock.AnyArg(),
log.ImageCount,
sqlmock.AnyArg(),
+ sqlmock.AnyArg(), // image_input_size
+ sqlmock.AnyArg(), // image_output_size
+ sqlmock.AnyArg(), // image_size_source
+ sqlmock.AnyArg(), // image_size_breakdown
serviceTier,
sqlmock.AnyArg(),
sqlmock.AnyArg(),
@@ -230,12 +238,74 @@ func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) {
require.Len(t, prepared.args, len(usageLogInsertArgTypes))
}
+func TestPrepareUsageLogInsert_PersistsImageSizeMetadata(t *testing.T) {
+ imageSize := "4K"
+ inputSize := "1024x1024"
+ outputSize := "3840x2160"
+ source := "output"
+ prepared := prepareUsageLogInsert(&service.UsageLog{
+ UserID: 1,
+ APIKeyID: 2,
+ AccountID: 3,
+ RequestID: "req-image-metadata",
+ Model: "gpt-image-2",
+ RequestedModel: "gpt-image-2",
+ ImageCount: 2,
+ ImageSize: &imageSize,
+ ImageInputSize: &inputSize,
+ ImageOutputSize: &outputSize,
+ ImageSizeSource: &source,
+ ImageSizeBreakdown: map[string]int{"1K": 1, "4K": 1},
+ CreatedAt: time.Date(2025, 1, 6, 12, 0, 0, 0, time.UTC),
+ })
+
+ require.Equal(t, sql.NullString{String: imageSize, Valid: true}, prepared.args[34])
+ require.Equal(t, sql.NullString{String: inputSize, Valid: true}, prepared.args[35])
+ require.Equal(t, sql.NullString{String: outputSize, Valid: true}, prepared.args[36])
+ require.Equal(t, sql.NullString{String: source, Valid: true}, prepared.args[37])
+ breakdownJSON, ok := prepared.args[38].(string)
+ require.True(t, ok)
+ require.JSONEq(t, `{"1K":1,"4K":1}`, breakdownJSON)
+}
+
func TestCoalesceTrimmedString(t *testing.T) {
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback"))
require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback"))
require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback"))
}
+func TestAppendUsageLogBillingModeWhereCondition(t *testing.T) {
+ tests := []struct {
+ name string
+ billingMode string
+ wantCondition string
+ }{
+ {
+ name: "image includes legacy image rows",
+ billingMode: string(service.BillingModeImage),
+ wantCondition: "(billing_mode = $1 OR COALESCE(image_count, 0) > 0)",
+ },
+ {
+ name: "token includes legacy non-image rows",
+ billingMode: string(service.BillingModeToken),
+ wantCondition: "(billing_mode = $1 OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))",
+ },
+ {
+ name: "per request remains exact",
+ billingMode: string(service.BillingModePerRequest),
+ wantCondition: "billing_mode = $1",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ conditions, args := appendUsageLogBillingModeWhereCondition(nil, nil, tt.billingMode)
+ require.Equal(t, []string{tt.wantCondition}, conditions)
+ require.Equal(t, []any{tt.billingMode}, args)
+ })
+ }
+}
+
func anySliceToDriverValues(values []any) []driver.Value {
out := make([]driver.Value, 0, len(values))
for _, value := range values {
@@ -528,6 +598,63 @@ func (s usageLogScannerStub) Scan(dest ...any) error {
}
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
+ t.Run("image_size_metadata_is_scanned", func(t *testing.T) {
+ now := time.Now().UTC()
+ log, err := scanUsageLog(usageLogScannerStub{values: []any{
+ int64(4),
+ int64(13),
+ int64(23),
+ int64(33),
+ sql.NullString{Valid: true, String: "req-image-metadata"},
+ "gpt-image-2",
+ sql.NullString{Valid: true, String: "gpt-image-2"},
+ sql.NullString{},
+ sql.NullInt64{},
+ sql.NullInt64{},
+ 0, 0, 0, 0, 0, 0,
+ 0, 0.0, // image_output_tokens, image_output_cost
+ 0.0, 0.0, 0.0, 0.0, 0.8, 0.8,
+ 1.0,
+ sql.NullFloat64{},
+ int16(service.BillingTypeBalance),
+ int16(service.RequestTypeSync),
+ false,
+ false,
+ sql.NullInt64{},
+ sql.NullInt64{},
+ sql.NullString{},
+ sql.NullString{},
+ 2,
+ sql.NullString{Valid: true, String: "4K"},
+ sql.NullString{Valid: true, String: "1024x1024"},
+ sql.NullString{Valid: true, String: "3840x2160"},
+ sql.NullString{Valid: true, String: "output"},
+ sql.NullString{Valid: true, String: `{"4K":2}`},
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullString{},
+ false,
+ sql.NullInt64{},
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullString{},
+ sql.NullFloat64{},
+ now,
+ }})
+ require.NoError(t, err)
+ require.Equal(t, 2, log.ImageCount)
+ require.NotNil(t, log.ImageSize)
+ require.Equal(t, "4K", *log.ImageSize)
+ require.NotNil(t, log.ImageInputSize)
+ require.Equal(t, "1024x1024", *log.ImageInputSize)
+ require.NotNil(t, log.ImageOutputSize)
+ require.Equal(t, "3840x2160", *log.ImageOutputSize)
+ require.NotNil(t, log.ImageSizeSource)
+ require.Equal(t, "output", *log.ImageSizeSource)
+ require.Equal(t, map[string]int{"4K": 2}, log.ImageSizeBreakdown)
+ })
+
t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
now := time.Now().UTC()
log, err := scanUsageLog(usageLogScannerStub{values: []any{
@@ -567,6 +694,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
+ sql.NullString{}, // image_input_size
+ sql.NullString{}, // image_output_size
+ sql.NullString{}, // image_size_source
+ sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
@@ -615,6 +746,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
+ sql.NullString{}, // image_input_size
+ sql.NullString{}, // image_output_size
+ sql.NullString{}, // image_size_source
+ sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "flex"},
sql.NullString{},
sql.NullString{},
@@ -663,6 +798,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
sql.NullString{},
0,
sql.NullString{},
+ sql.NullString{}, // image_input_size
+ sql.NullString{}, // image_output_size
+ sql.NullString{}, // image_size_source
+ sql.NullString{}, // image_size_breakdown
sql.NullString{Valid: true, String: "priority"},
sql.NullString{},
sql.NullString{},
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index 1566756d..610d9a7b 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -334,7 +334,8 @@ func normalizeEmailAuthIdentitySubject(email string) string {
}
if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) {
+ strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, service.DingTalkConnectSyntheticEmailDomain) {
return ""
}
return normalized
@@ -956,7 +957,7 @@ func userSignupSourceOrDefault(signupSource string) string {
switch strings.TrimSpace(strings.ToLower(signupSource)) {
case "", "email":
return "email"
- case "linuxdo", "wechat", "oidc":
+ case "linuxdo", "wechat", "oidc", "dingtalk":
return strings.TrimSpace(strings.ToLower(signupSource))
default:
return "email"
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 0cc5a211..3770d585 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -68,6 +68,7 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_bound": false,
"oidc_bound": false,
"wechat_bound": false,
+ "dingtalk_bound": false,
"identities": {
"email": {
"provider": "email",
@@ -104,6 +105,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "dingtalk": {
+ "provider": "dingtalk",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"identity_bindings": {
@@ -142,6 +151,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "dingtalk": {
+ "provider": "dingtalk",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"auth_bindings": {
@@ -180,6 +197,14 @@ func TestAPIContracts(t *testing.T) {
"can_bind": true,
"can_unbind": false,
"bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
+ },
+ "dingtalk": {
+ "provider": "dingtalk",
+ "bound": false,
+ "bound_count": 0,
+ "can_bind": true,
+ "can_unbind": false,
+ "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile"
}
},
"run_mode": "standard"
@@ -554,6 +579,10 @@ func TestAPIContracts(t *testing.T) {
"first_token_ms": 50,
"image_count": 0,
"image_size": null,
+ "image_input_size": null,
+ "image_output_size": null,
+ "image_size_source": null,
+ "image_size_breakdown": null,
"media_type": null,
"cache_ttl_overridden": false,
"created_at": "2025-01-02T03:04:05Z",
@@ -672,6 +701,22 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
+ "dingtalk_connect_enabled": false,
+ "dingtalk_connect_bypass_registration": false,
+ "dingtalk_connect_client_id": "",
+ "dingtalk_connect_client_secret_configured": false,
+ "dingtalk_connect_redirect_url": "",
+ "dingtalk_connect_internal_corp_id": "",
+ "dingtalk_connect_corp_restriction_policy": "",
+ "dingtalk_connect_sync_corp_email": false,
+ "dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
+ "dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
+ "dingtalk_connect_sync_dept": false,
+ "dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
+ "dingtalk_connect_sync_dept_attr_name": "钉钉部门",
+ "dingtalk_connect_sync_display_name": false,
+ "dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
+ "dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
"oidc_connect_enabled": false,
"oidc_connect_provider_name": "OIDC",
"oidc_connect_client_id": "",
@@ -744,6 +789,11 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
+ "auth_source_default_dingtalk_balance": 0,
+ "auth_source_default_dingtalk_concurrency": 5,
+ "auth_source_default_dingtalk_subscriptions": [],
+ "auth_source_default_dingtalk_grant_on_signup": false,
+ "auth_source_default_dingtalk_grant_on_first_bind": false,
"force_email_on_third_party_signup": false,
"default_concurrency": 5,
"default_balance": 1.25,
@@ -784,14 +834,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": true,
"openai_fast_policy_settings": {
- "rules": [
- {
- "service_tier": "priority",
- "action": "filter",
- "scope": "all",
- "fallback_action": "pass"
- }
- ]
+ "rules": []
},
"custom_menu_items": [],
"custom_endpoints": [],
@@ -815,6 +858,7 @@ func TestAPIContracts(t *testing.T) {
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
+ "payment_alipay_force_qrcode": false,
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
@@ -917,6 +961,22 @@ func TestAPIContracts(t *testing.T) {
"linuxdo_connect_client_id": "",
"linuxdo_connect_client_secret_configured": false,
"linuxdo_connect_redirect_url": "",
+ "dingtalk_connect_enabled": false,
+ "dingtalk_connect_bypass_registration": false,
+ "dingtalk_connect_client_id": "",
+ "dingtalk_connect_client_secret_configured": false,
+ "dingtalk_connect_redirect_url": "",
+ "dingtalk_connect_internal_corp_id": "",
+ "dingtalk_connect_corp_restriction_policy": "",
+ "dingtalk_connect_sync_corp_email": false,
+ "dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email",
+ "dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱",
+ "dingtalk_connect_sync_dept": false,
+ "dingtalk_connect_sync_dept_attr_key": "dingtalk_department",
+ "dingtalk_connect_sync_dept_attr_name": "钉钉部门",
+ "dingtalk_connect_sync_display_name": false,
+ "dingtalk_connect_sync_display_name_attr_key": "dingtalk_name",
+ "dingtalk_connect_sync_display_name_attr_name": "钉钉姓名",
"oidc_connect_enabled": true,
"oidc_connect_provider_name": "ConfigOIDC",
"oidc_connect_client_id": "oidc-config-client",
@@ -999,14 +1059,7 @@ func TestAPIContracts(t *testing.T) {
"payment_visible_method_wxpay_enabled": false,
"openai_advanced_scheduler_enabled": false,
"openai_fast_policy_settings": {
- "rules": [
- {
- "service_tier": "priority",
- "action": "filter",
- "scope": "all",
- "fallback_action": "pass"
- }
- ]
+ "rules": []
},
"payment_enabled": false,
"payment_min_amount": 0,
@@ -1028,6 +1081,7 @@ func TestAPIContracts(t *testing.T) {
"payment_cancel_rate_limit_window": 0,
"payment_cancel_rate_limit_unit": "",
"payment_cancel_rate_limit_window_mode": "",
+ "payment_alipay_force_qrcode": false,
"balance_low_notify_enabled": false,
"account_quota_notify_enabled": false,
"balance_low_notify_threshold": 0,
@@ -1084,6 +1138,11 @@ func TestAPIContracts(t *testing.T) {
"auth_source_default_wechat_subscriptions": [],
"auth_source_default_wechat_grant_on_signup": false,
"auth_source_default_wechat_grant_on_first_bind": false,
+ "auth_source_default_dingtalk_balance": 0,
+ "auth_source_default_dingtalk_concurrency": 5,
+ "auth_source_default_dingtalk_subscriptions": [],
+ "auth_source_default_dingtalk_grant_on_signup": false,
+ "auth_source_default_dingtalk_grant_on_first_bind": false,
"force_email_on_third_party_signup": false
}
}`,
@@ -1194,10 +1253,10 @@ func newContractDeps(t *testing.T) *contractDeps {
settingService := service.NewSettingService(settingRepo, cfg)
adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
- authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
+ authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
- adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil)
+ adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil, nil)
adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go
index 972c1eaf..c15f534e 100644
--- a/backend/internal/server/middleware/api_key_auth.go
+++ b/backend/internal/server/middleware/api_key_auth.go
@@ -92,6 +92,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
clientIP := ip.GetTrustedClientIP(c)
allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist)
if !allowed {
+ service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
AbortWithError(c, 403, "ACCESS_DENIED", "Access denied")
return
}
diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go
index 4a4ab0f9..d6760d8d 100644
--- a/backend/internal/server/middleware/api_key_auth_test.go
+++ b/backend/internal/server/middleware/api_key_auth_test.go
@@ -333,6 +333,15 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg)
router := gin.New()
require.NoError(t, router.SetTrustedProxies(nil))
+ var markedBusinessLimited bool
+ var businessLimitedReason string
+ router.Use(func(c *gin.Context) {
+ c.Next()
+ markedBusinessLimited = service.HasOpsClientBusinessLimited(c)
+ if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok {
+ businessLimitedReason, _ = v.(string)
+ }
+ })
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
router.GET("/t", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
@@ -349,6 +358,8 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T)
require.Equal(t, http.StatusForbidden, w.Code)
require.Contains(t, w.Body.String(), "ACCESS_DENIED")
+ require.True(t, markedBusinessLimited)
+ require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason)
}
func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) {
diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go
index 157f06b0..050e3bc6 100644
--- a/backend/internal/server/middleware/backend_mode_guard.go
+++ b/backend/internal/server/middleware/backend_mode_guard.go
@@ -42,15 +42,19 @@ func backendModeAllowsAuthPath(path string) bool {
"/auth/oauth/oidc/callback",
"/auth/oauth/github/callback",
"/auth/oauth/google/callback",
+ "/auth/oauth/dingtalk/callback",
"/auth/oauth/linuxdo/complete-registration",
"/auth/oauth/wechat/complete-registration",
"/auth/oauth/oidc/complete-registration",
+ "/auth/oauth/dingtalk/complete-registration",
"/auth/oauth/linuxdo/create-account",
"/auth/oauth/wechat/create-account",
"/auth/oauth/oidc/create-account",
+ "/auth/oauth/dingtalk/create-account",
"/auth/oauth/linuxdo/bind-login",
"/auth/oauth/wechat/bind-login",
"/auth/oauth/oidc/bind-login",
+ "/auth/oauth/dingtalk/bind-login",
} {
if strings.HasSuffix(path, suffix) {
return true
diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go
index de9c9ec9..df2edde6 100644
--- a/backend/internal/server/middleware/backend_mode_guard_test.go
+++ b/backend/internal/server/middleware/backend_mode_guard_test.go
@@ -270,6 +270,36 @@ func TestBackendModeAuthGuard(t *testing.T) {
path: "/api/v1/auth/oauth/google/callback",
wantStatus: http.StatusOK,
},
+ {
+ name: "enabled_blocks_dingtalk_oauth_start",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/dingtalk/start",
+ wantStatus: http.StatusForbidden,
+ },
+ {
+ name: "enabled_allows_dingtalk_oauth_callback",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/dingtalk/callback",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_dingtalk_complete_registration",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/dingtalk/complete-registration",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_dingtalk_create_account",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/dingtalk/create-account",
+ wantStatus: http.StatusOK,
+ },
+ {
+ name: "enabled_allows_dingtalk_bind_login",
+ enabled: "true",
+ path: "/api/v1/auth/oauth/dingtalk/bind-login",
+ wantStatus: http.StatusOK,
+ },
{
name: "enabled_allows_oauth_pending_exchange",
enabled: "true",
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index de1d06ad..b8af9cc5 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -180,22 +180,17 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// Error logs (legacy)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
- ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
- ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
// Request errors (client-visible failures)
ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
- ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
- ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
// Upstream errors (independent upstream failures)
ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
- ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
// Request drilldown (success + error)
@@ -309,6 +304,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable)
accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable)
accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels)
+ accounts.POST("/:id/models/sync-upstream", h.Admin.Account.SyncUpstreamModels)
accounts.POST("/batch", h.Admin.Account.BatchCreate)
accounts.GET("/data", h.Admin.Account.ExportData)
accounts.POST("/data", h.Admin.Account.ImportData)
@@ -595,6 +591,7 @@ func registerChannelRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
{
channels.GET("", h.Admin.Channel.List)
channels.GET("/model-pricing", h.Admin.Channel.GetModelDefaultPricing)
+ channels.GET("/pricing/sync-models", h.Admin.Channel.SyncPricingModels)
channels.GET("/:id", h.Admin.Channel.GetByID)
channels.POST("", h.Admin.Channel.Create)
channels.PUT("/:id", h.Admin.Channel.Update)
diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go
index 54d40e92..19d0fd2a 100644
--- a/backend/internal/server/routes/auth.go
+++ b/backend/internal/server/routes/auth.go
@@ -182,6 +182,32 @@ func RegisterAuthRoutes(
}),
h.Auth.CreateOIDCOAuthAccount,
)
+ auth.GET("/oauth/dingtalk/start", h.Auth.DingTalkOAuthStart)
+ auth.GET("/oauth/dingtalk/bind/start", func(c *gin.Context) {
+ query := c.Request.URL.Query()
+ query.Set("intent", "bind_current_user")
+ c.Request.URL.RawQuery = query.Encode()
+ h.Auth.DingTalkOAuthStart(c)
+ })
+ auth.GET("/oauth/dingtalk/callback", h.Auth.DingTalkOAuthCallback)
+ auth.POST("/oauth/dingtalk/complete-registration",
+ rateLimiter.LimitWithOptions("oauth-dingtalk-complete", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CompleteDingTalkOAuthRegistration,
+ )
+ auth.POST("/oauth/dingtalk/bind-login",
+ rateLimiter.LimitWithOptions("oauth-dingtalk-bind-login", 20, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.BindDingTalkOAuthLogin,
+ )
+ auth.POST("/oauth/dingtalk/create-account",
+ rateLimiter.LimitWithOptions("oauth-dingtalk-create-account", 10, time.Minute, middleware.RateLimitOptions{
+ FailureMode: middleware.RateLimitFailClose,
+ }),
+ h.Auth.CreateDingTalkOAuthAccount,
+ )
}
// 公开设置(无需认证)
diff --git a/backend/internal/service/account_credentials_redact.go b/backend/internal/service/account_credentials_redact.go
new file mode 100644
index 00000000..76c2d1de
--- /dev/null
+++ b/backend/internal/service/account_credentials_redact.go
@@ -0,0 +1,50 @@
+package service
+
+// SensitiveCredentialKeys 列出 Account.Credentials JSON map 中绝不允许返回到前端的子键。
+// dto 层做响应脱敏、service 层做更新合并都引用此清单——新增凭证类型时务必同步。
+var SensitiveCredentialKeys = []string{
+ // OAuth
+ "access_token", "refresh_token", "id_token",
+ // API Key 类
+ "api_key", "session_key", "cookie",
+ // 云服务凭据
+ "aws_secret_access_key", "aws_session_token",
+ "service_account_json", "service_account", "private_key",
+}
+
+var sensitiveCredentialKeySet = func() map[string]struct{} {
+ m := make(map[string]struct{}, len(SensitiveCredentialKeys))
+ for _, k := range SensitiveCredentialKeys {
+ m[k] = struct{}{}
+ }
+ return m
+}()
+
+// IsSensitiveCredentialKey 判断指定键是否为敏感凭证子键。
+func IsSensitiveCredentialKey(key string) bool {
+ _, ok := sensitiveCredentialKeySet[key]
+ return ok
+}
+
+// MergePreservingSensitiveCreds 把 incoming 写入 existing 之上,但敏感子键采用"incoming 没提供就保留 existing"
+// 的语义。返回新的 map,不修改入参。
+//
+// 用途:前端编辑账号通常采用"全对象 PUT"模式;脱敏后前端 spread 旧 credentials 时不会带上敏感键,
+// 直接覆盖会清空已有 token。此函数保证:
+// - 非敏感键:完全由 incoming 决定(用户可以编辑、删除非敏感字段)。
+// - 敏感键:incoming 显式提供则覆盖(用户主动旋转 token),否则保留 existing。
+func MergePreservingSensitiveCreds(existing, incoming map[string]any) map[string]any {
+ out := make(map[string]any, len(incoming)+len(SensitiveCredentialKeys))
+ for k, v := range incoming {
+ out[k] = v
+ }
+ for _, key := range SensitiveCredentialKeys {
+ if _, hasIncoming := incoming[key]; hasIncoming {
+ continue
+ }
+ if existingVal, ok := existing[key]; ok {
+ out[key] = existingVal
+ }
+ }
+ return out
+}
diff --git a/backend/internal/service/account_credentials_redact_test.go b/backend/internal/service/account_credentials_redact_test.go
new file mode 100644
index 00000000..05f37da9
--- /dev/null
+++ b/backend/internal/service/account_credentials_redact_test.go
@@ -0,0 +1,90 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestMergePreservingSensitiveCreds_PreservesSensitiveWhenIncomingMissing(t *testing.T) {
+ existing := map[string]any{
+ "refresh_token": "rt-old",
+ "access_token": "at-old",
+ "api_key": "sk-old",
+ "base_url": "https://old.example.com",
+ }
+ incoming := map[string]any{
+ "base_url": "https://new.example.com",
+ "model_mapping": map[string]any{"foo": "bar"},
+ }
+
+ out := MergePreservingSensitiveCreds(existing, incoming)
+
+ require.Equal(t, "rt-old", out["refresh_token"], "incoming 没传 refresh_token,应保留 existing")
+ require.Equal(t, "at-old", out["access_token"])
+ require.Equal(t, "sk-old", out["api_key"])
+ require.Equal(t, "https://new.example.com", out["base_url"], "非敏感键由 incoming 决定")
+ require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
+}
+
+func TestMergePreservingSensitiveCreds_OverwritesWhenIncomingProvidesSensitive(t *testing.T) {
+ existing := map[string]any{
+ "refresh_token": "rt-old",
+ "api_key": "sk-old",
+ }
+ incoming := map[string]any{
+ "refresh_token": "rt-new",
+ // 显式没传 api_key —— 应保留
+ }
+ out := MergePreservingSensitiveCreds(existing, incoming)
+ require.Equal(t, "rt-new", out["refresh_token"], "incoming 显式传入应覆盖")
+ require.Equal(t, "sk-old", out["api_key"], "incoming 没传应保留")
+}
+
+func TestMergePreservingSensitiveCreds_DoesNotMutateInputs(t *testing.T) {
+ existing := map[string]any{"refresh_token": "rt"}
+ incoming := map[string]any{"base_url": "x"}
+
+ _ = MergePreservingSensitiveCreds(existing, incoming)
+
+ require.Equal(t, "rt", existing["refresh_token"])
+ require.NotContains(t, existing, "base_url")
+ require.Equal(t, "x", incoming["base_url"])
+ require.NotContains(t, incoming, "refresh_token")
+}
+
+func TestMergePreservingSensitiveCreds_NilInputs(t *testing.T) {
+ out := MergePreservingSensitiveCreds(nil, map[string]any{"base_url": "x"})
+ require.Equal(t, "x", out["base_url"])
+ require.NotContains(t, out, "refresh_token")
+
+ out2 := MergePreservingSensitiveCreds(map[string]any{"refresh_token": "rt"}, nil)
+ require.Equal(t, "rt", out2["refresh_token"])
+}
+
+func TestMergePreservingSensitiveCreds_NonSensitiveDeletionAllowed(t *testing.T) {
+ existing := map[string]any{
+ "refresh_token": "rt",
+ "base_url": "https://old",
+ "project_id": "p1",
+ }
+ incoming := map[string]any{
+ "base_url": "https://new",
+ // 不带 project_id —— 等同删除(非敏感键由 incoming 决定)
+ }
+ out := MergePreservingSensitiveCreds(existing, incoming)
+ require.Equal(t, "rt", out["refresh_token"], "敏感键保留")
+ require.Equal(t, "https://new", out["base_url"])
+ require.NotContains(t, out, "project_id", "非敏感键 incoming 不传 = 删除")
+}
+
+func TestIsSensitiveCredentialKey(t *testing.T) {
+ require.True(t, IsSensitiveCredentialKey("refresh_token"))
+ require.True(t, IsSensitiveCredentialKey("api_key"))
+ require.True(t, IsSensitiveCredentialKey("private_key"))
+ require.False(t, IsSensitiveCredentialKey("base_url"))
+ require.False(t, IsSensitiveCredentialKey(""))
+ require.False(t, IsSensitiveCredentialKey("model_mapping"))
+}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index b8fc1d4c..5d4a1e8a 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -295,14 +295,16 @@ func NewAccountUsageService(
// OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
// API Key账号: 不支持usage查询
-func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
+func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64, force ...bool) (*UsageInfo, error) {
+ forceProbe := len(force) > 0 && force[0]
+
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth {
- usage, err := s.getOpenAIUsage(ctx, account)
+ usage, err := s.getOpenAIUsage(ctx, account, forceProbe)
if err == nil {
s.tryClearRecoverableAccountError(ctx, account)
}
@@ -492,7 +494,7 @@ func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID
}
}
-func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
+func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account, force bool) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{UpdatedAt: &now}
@@ -507,7 +509,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou
usage.SevenDay = progress
}
- if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) {
+ if (force || shouldRefreshOpenAICodexSnapshot(account, usage, now)) && s.shouldProbeOpenAICodexSnapshot(account.ID, now, force) {
if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 {
mergeAccountExtra(account, updates)
if usage.UpdatedAt == nil {
@@ -577,13 +579,16 @@ func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool {
return now.Sub(ts) >= openAIProbeCacheTTL
}
-func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool {
+func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time, force ...bool) bool {
if s == nil || s.cache == nil || accountID <= 0 {
return true
}
- if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
- if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
- return false
+ forceProbe := len(force) > 0 && force[0]
+ if !forceProbe {
+ if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok {
+ if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL {
+ return false
+ }
}
}
s.cache.openAIProbeCache.Store(accountID, now)
diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go
index 28b49838..e0390c4c 100644
--- a/backend/internal/service/account_usage_service_test.go
+++ b/backend/internal/service/account_usage_service_test.go
@@ -140,7 +140,7 @@ func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit(
},
}
- usage, err := svc.getOpenAIUsage(context.Background(), account)
+ usage, err := svc.getOpenAIUsage(context.Background(), account, false)
if err != nil {
t.Fatalf("getOpenAIUsage() error = %v", err)
}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 24afbd68..9b5c1afc 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -397,6 +397,7 @@ type GenerateRedeemCodesInput struct {
Value float64
GroupID *int64 // 订阅类型专用:关联的分组ID
ValidityDays int // 订阅类型专用:有效天数
+ ExpiresAt *time.Time
}
type ProxyBatchDeleteResult struct {
@@ -1238,7 +1239,7 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6
providerKey := strings.TrimSpace(input.ProviderKey)
providerSubject := strings.TrimSpace(input.ProviderSubject)
if providerType == "" {
- return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat")
+ return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, wechat, or dingtalk")
}
if providerKey == "" || providerSubject == "" {
return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required")
@@ -1493,6 +1494,8 @@ func normalizeAdminAuthIdentityProviderType(input string) string {
return "oidc"
case "wechat":
return "wechat"
+ case "dingtalk":
+ return "dingtalk"
default:
return ""
}
@@ -2470,7 +2473,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Notes = normalizeAccountNotes(input.Notes)
}
if len(input.Credentials) > 0 {
- account.Credentials = input.Credentials
+ // 敏感子键采用"incoming 没提供就保留"的合并语义:前端响应已脱敏,
+ // 全对象 PUT 编辑时不会再带回 token,避免覆盖时清空已有凭证。
+ account.Credentials = MergePreservingSensitiveCreds(account.Credentials, input.Credentials)
}
// Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。
// 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。
@@ -2966,6 +2971,10 @@ func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*Redeem
}
func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) {
+ if input.ExpiresAt != nil && !input.ExpiresAt.After(time.Now()) {
+ return nil, ErrRedeemCodeExpired
+ }
+
// 如果是订阅类型,验证必须有 GroupID
if input.Type == RedeemTypeSubscription {
if input.GroupID == nil {
@@ -2988,10 +2997,11 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener
return nil, err
}
code := RedeemCode{
- Code: codeValue,
- Type: input.Type,
- Value: input.Value,
- Status: StatusUnused,
+ Code: codeValue,
+ Type: input.Type,
+ Value: input.Value,
+ Status: StatusUnused,
+ ExpiresAt: input.ExpiresAt,
}
// 订阅类型专用字段
if input.Type == RedeemTypeSubscription {
diff --git a/backend/internal/service/admin_service_credentials_merge_test.go b/backend/internal/service/admin_service_credentials_merge_test.go
new file mode 100644
index 00000000..8250db28
--- /dev/null
+++ b/backend/internal/service/admin_service_credentials_merge_test.go
@@ -0,0 +1,117 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type updateAccountCredsRepoStub struct {
+ mockAccountRepoForGemini
+ account *Account
+ updateCalls int
+}
+
+func (r *updateAccountCredsRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
+ return r.account, nil
+}
+
+func (r *updateAccountCredsRepoStub) Update(ctx context.Context, account *Account) error {
+ r.updateCalls++
+ r.account = account
+ return nil
+}
+
+func TestUpdateAccount_PreservesSensitiveCredsWhenIncomingOmits(t *testing.T) {
+ accountID := int64(202)
+ repo := &updateAccountCredsRepoStub{
+ account: &Account{
+ ID: accountID,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Credentials: map[string]any{
+ "refresh_token": "rt-existing",
+ "access_token": "at-existing",
+ "id_token": "id-existing",
+ "base_url": "https://old.example.com",
+ },
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ // 模拟前端编辑:仅修改 base_url,没有传 token(脱敏后前端 spread 拿不到敏感键)
+ updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
+ Credentials: map[string]any{
+ "base_url": "https://new.example.com",
+ },
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+ require.Equal(t, 1, repo.updateCalls)
+
+ // 敏感键应保留
+ require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"])
+ require.Equal(t, "at-existing", repo.account.Credentials["access_token"])
+ require.Equal(t, "id-existing", repo.account.Credentials["id_token"])
+ // 非敏感键被替换
+ require.Equal(t, "https://new.example.com", repo.account.Credentials["base_url"])
+}
+
+func TestUpdateAccount_ExplicitNewTokenOverwrites(t *testing.T) {
+ accountID := int64(203)
+ repo := &updateAccountCredsRepoStub{
+ account: &Account{
+ ID: accountID,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Credentials: map[string]any{
+ "refresh_token": "rt-old",
+ "api_key": "sk-old",
+ },
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
+ Credentials: map[string]any{
+ "refresh_token": "rt-new",
+ // api_key 没传 → 应保留旧值
+ },
+ })
+ require.NoError(t, err)
+ require.NotNil(t, updated)
+
+ require.Equal(t, "rt-new", repo.account.Credentials["refresh_token"])
+ require.Equal(t, "sk-old", repo.account.Credentials["api_key"])
+}
+
+func TestUpdateAccount_EmptyCredentialsSkipsUpdate(t *testing.T) {
+ accountID := int64(204)
+ repo := &updateAccountCredsRepoStub{
+ account: &Account{
+ ID: accountID,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ Credentials: map[string]any{
+ "refresh_token": "rt-existing",
+ },
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ _, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
+ Credentials: map[string]any{}, // len == 0 → 闸门跳过
+ Name: "renamed",
+ })
+ require.NoError(t, err)
+
+ require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"], "空 credentials 不应触碰已有 token")
+ require.Equal(t, "renamed", repo.account.Name)
+}
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index a76e59fb..951f324c 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -628,11 +628,6 @@ urlFallbackLoop:
return nil, err
}
- // Capture upstream request body for ops retry of this attempt.
- if p.c != nil && len(p.body) > 0 {
- p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
- }
-
resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && resp == nil {
err = errors.New("upstream returned nil response")
@@ -2094,7 +2089,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 解析请求以获取 image_size(用于图片计费)
- imageSize := s.extractImageSize(body)
+ imageInputSize := s.extractImageInputSize(body)
+ imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
switch action {
case "generateContent", "streamGenerateContent":
@@ -2465,6 +2461,7 @@ handleSuccess:
ClientDisconnect: clientDisconnect,
ImageCount: imageCount,
ImageSize: imageSize,
+ ImageInputSize: imageInputSize,
}, nil
}
@@ -4063,21 +4060,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
}
}
-// extractImageSize 从 Gemini 请求中提取 image_size 参数
-func (s *AntigravityGatewayService) extractImageSize(body []byte) string {
+func (s *AntigravityGatewayService) extractImageInputSize(body []byte) string {
var req antigravity.GeminiRequest
if err := json.Unmarshal(body, &req); err != nil {
- return "2K" // 默认 2K
+ return ""
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
- size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
- if size == "1K" || size == "2K" || size == "4K" {
- return size
- }
+ return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)
}
- return "2K" // 默认 2K
+ return ""
}
// isImageGenerationModel 判断模型是否为图片生成模型
diff --git a/backend/internal/service/antigravity_image_test.go b/backend/internal/service/antigravity_image_test.go
index 7fd2f843..76269dd3 100644
--- a/backend/internal/service/antigravity_image_test.go
+++ b/backend/internal/service/antigravity_image_test.go
@@ -46,15 +46,15 @@ func TestExtractImageSize_ValidSizes(t *testing.T) {
// 1K
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
- require.Equal(t, "1K", svc.extractImageSize(body))
+ require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 2K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 4K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
- require.Equal(t, "4K", svc.extractImageSize(body))
+ require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
@@ -62,10 +62,10 @@ func TestExtractImageSize_CaseInsensitive(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
- require.Equal(t, "1K", svc.extractImageSize(body))
+ require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
- require.Equal(t, "4K", svc.extractImageSize(body))
+ require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
@@ -74,15 +74,15 @@ func TestExtractImageSize_Default(t *testing.T) {
// 无 generationConfig
body := []byte(`{"contents":[]}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 有 generationConfig 但无 imageConfig
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 有 imageConfig 但无 imageSize
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
@@ -90,10 +90,10 @@ func TestExtractImageSize_InvalidJSON(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`not valid json`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"broken":`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
@@ -101,11 +101,11 @@ func TestExtractImageSize_EmptySize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
// 空格
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
@@ -113,11 +113,11 @@ func TestExtractImageSize_InvalidSize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
- require.Equal(t, "2K", svc.extractImageSize(body))
+ require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body)))
}
diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go
index e3c8298c..3478fda5 100644
--- a/backend/internal/service/auth_oauth_email_flow.go
+++ b/backend/internal/service/auth_oauth_email_flow.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
+ "log/slog"
"net/mail"
"strings"
"time"
@@ -18,7 +19,7 @@ func normalizeOAuthSignupSource(signupSource string) string {
switch signupSource {
case "", "email":
return "email"
- case "linuxdo", "wechat", "oidc", "github", "google":
+ case "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
return signupSource
default:
return "email"
@@ -71,7 +72,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i
if err != nil {
return nil, ErrInvitationCodeInvalid
}
- if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
return nil, ErrInvitationCodeInvalid
}
return redeemCode, nil
@@ -109,7 +110,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
if s == nil {
return nil, nil, ErrServiceUnavailable
}
- if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@@ -118,18 +119,22 @@ func (s *AuthService) RegisterOAuthEmailAccount(
return nil, nil, ErrEmailReserved
}
if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil {
+ slog.Error("oauth email register: policy rejected", "email", email, "error", err.Error())
return nil, nil, err
}
if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil {
+ slog.Error("oauth email register: verify code failed", "email", email, "error", err.Error())
return nil, nil, err
}
if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil {
+ slog.Error("oauth email register: invitation failed", "email", email, "error", err.Error())
return nil, nil, err
}
existsEmail, err := s.userRepo.ExistsByEmail(ctx, email)
if err != nil {
+ slog.Error("oauth email register: ExistsByEmail failed", "email", email, "error", err.Error())
return nil, nil, ErrServiceUnavailable
}
if existsEmail {
@@ -158,6 +163,7 @@ func (s *AuthService) RegisterOAuthEmailAccount(
if errors.Is(err, ErrEmailExists) {
return nil, nil, ErrEmailExists
}
+ slog.Error("oauth email register: userRepo.Create failed", "email", email, "signup_source", signupSource, "error", err.Error())
return nil, nil, ErrServiceUnavailable
}
@@ -181,7 +187,7 @@ func (s *AuthService) RegisterVerifiedOAuthEmailAccount(
if s == nil {
return nil, nil, ErrServiceUnavailable
}
- if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@@ -358,6 +364,7 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
UsedAt: entity.UsedAt,
Notes: oauthEmailFlowStringValue(entity.Notes),
CreatedAt: entity.CreatedAt,
+ ExpiresAt: entity.ExpiresAt,
GroupID: entity.GroupID,
ValidityDays: entity.ValidityDays,
}, nil
@@ -368,7 +375,11 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit
func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error {
if client := s.oauthEmailFlowClient(ctx); client != nil {
affected, err := client.RedeemCode.Update().
- Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)).
+ Where(
+ redeemcode.IDEQ(invitationID),
+ redeemcode.StatusEQ(StatusUnused),
+ redeemcode.Or(redeemcode.ExpiresAtIsNil(), redeemcode.ExpiresAtGT(time.Now().UTC())),
+ ).
SetStatus(StatusUsed).
SetUsedBy(userID).
SetUsedAt(time.Now().UTC()).
@@ -396,6 +407,11 @@ func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, cod
SetStatus(code.Status).
SetNotes(code.Notes).
SetValidityDays(code.ValidityDays)
+ if code.ExpiresAt != nil {
+ update = update.SetExpiresAt(*code.ExpiresAt)
+ } else {
+ update = update.ClearExpiresAt()
+ }
if code.UsedBy != nil {
update = update.SetUsedBy(*code.UsedBy)
} else {
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index e01e8217..ce2b3fa3 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -157,7 +157,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
return "", nil, ErrInvitationCodeInvalid
}
// 检查类型和状态
- if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status)
return "", nil, ErrInvitationCodeInvalid
}
@@ -560,11 +560,25 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil
}
+// canBypassRegistrationDisabledForOAuth 在钉钉企业模式(internal_only)且
+// dingtalk_connect_bypass_registration=true 时,允许跳过全局 registration_enabled 检查。
+func (s *AuthService) canBypassRegistrationDisabledForOAuth(ctx context.Context, signupSource string) bool {
+ if signupSource != "dingtalk" {
+ return false
+ }
+ cfg, err := s.settingService.GetDingTalkConnectOAuthConfig(ctx)
+ if err != nil || !cfg.Enabled || !cfg.BypassRegistration {
+ return false
+ }
+ return cfg.CorpRestrictionPolicy == "internal_only"
+}
+
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。
// invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。
// affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。
-func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) {
+// signupSource 标识来源渠道("dingtalk"/"linuxdo"/"wechat"/"oidc" 等),仅用于豁免检查。
+func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode, signupSource string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
@@ -587,7 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
- if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
+ if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) {
return nil, nil, ErrRegDisabled
}
@@ -601,7 +615,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
if err != nil {
return nil, nil, ErrInvitationCodeInvalid
}
- if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused {
+ if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() {
return nil, nil, ErrInvitationCodeInvalid
}
invitationRedeemCode = redeemCode
@@ -617,7 +631,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return nil, nil, fmt.Errorf("hash password: %w", err)
}
- signupSource := inferLegacySignupSource(email)
+ // 优先用 caller 显式传入的 signupSource(如 "dingtalk" / "linuxdo" / "oidc" / "wechat"),
+ // 否则才按邮箱后缀推断——避免有真实邮箱的 OAuth 用户被推断为 "email" 渠道,导致渠道授权错读。
+ if strings.TrimSpace(signupSource) == "" {
+ signupSource = inferLegacySignupSource(email)
+ }
grantPlan := s.resolveSignupGrantPlan(ctx, signupSource)
var defaultRPMLimit int
if s.settingService != nil {
@@ -779,6 +797,8 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource
return defaults.GitHub, true
case "google":
return defaults.Google, true
+ case "dingtalk":
+ return defaults.DingTalk, true
default:
return ProviderDefaultGrantSettings{}, false
}
@@ -992,6 +1012,8 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s
func inferLegacySignupSource(email string) string {
normalized := strings.ToLower(strings.TrimSpace(email))
switch {
+ case strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain):
+ return "dingtalk"
case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain):
return "linuxdo"
case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain):
@@ -1086,7 +1108,8 @@ func isReservedEmail(email string) bool {
normalized := strings.ToLower(strings.TrimSpace(email))
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) ||
- strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain)
+ strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) ||
+ strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT access token
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index acc44a38..ece02474 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -602,7 +602,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, 9.5, user.Balance)
- require.Equal(t, 2, user.Concurrency)
+ require.Equal(t, 5, user.Concurrency)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(31), assigner.calls[0].GroupID)
require.Equal(t, 5, assigner.calls[0].ValidityDays)
@@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
- tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "")
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "", "linuxdo")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.NotNil(t, user)
@@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
service.defaultSubAssigner = assigner
service.refreshTokenCache = &refreshTokenCacheStub{}
- tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "")
+ tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "", "linuxdo")
require.NoError(t, err)
require.NotNil(t, tokenPair)
require.Equal(t, existing.ID, user.ID)
@@ -667,3 +667,99 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA
require.Empty(t, repo.created)
require.Empty(t, assigner.calls)
}
+
+// newAuthServiceWithDingTalkCfg 构建一个含完整 DingTalk config 的 AuthService,
+// 用于测试 canBypassRegistrationDisabledForOAuth。
+func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.DingTalkConnectConfig) *AuthService {
+ cfg := &config.Config{
+ JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1},
+ Default: config.DefaultConfig{UserBalance: 3.5, UserConcurrency: 2},
+ DingTalk: dtCfg,
+ }
+ settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
+ return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil)
+}
+
+// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。
+func minDingTalkURLs() config.DingTalkConnectConfig {
+ return config.DingTalkConnectConfig{
+ ClientID: "test-client",
+ ClientSecret: "test-secret",
+ AuthorizeURL: "https://example.com/oauth2/auth",
+ TokenURL: "https://example.com/oauth2/token",
+ UserInfoURL: "https://example.com/oauth2/userinfo",
+ RedirectURL: "https://example.com/callback",
+ FrontendRedirectURL: "https://example.com/auth/callback",
+ DingTalkAppKind: "internal_app",
+ AppType: "internal",
+ }
+}
+
+func TestCanBypassRegistrationDisabledForOAuth(t *testing.T) {
+ cases := []struct {
+ name string
+ signupSource string
+ settings map[string]string
+ dtCfg config.DingTalkConnectConfig
+ want bool
+ }{
+ {
+ name: "non-dingtalk source → false",
+ signupSource: "linuxdo",
+ settings: map[string]string{},
+ dtCfg: minDingTalkURLs(),
+ want: false,
+ },
+ {
+ name: "dingtalk but cfg.Enabled=false → false",
+ signupSource: "dingtalk",
+ settings: map[string]string{
+ SettingKeyDingTalkConnectEnabled: "false",
+ SettingKeyDingTalkConnectBypassRegistration: "true",
+ SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
+ },
+ dtCfg: minDingTalkURLs(),
+ want: false,
+ },
+ {
+ name: "dingtalk enabled but BypassRegistration=false → false",
+ signupSource: "dingtalk",
+ settings: map[string]string{
+ SettingKeyDingTalkConnectEnabled: "true",
+ SettingKeyDingTalkConnectBypassRegistration: "false",
+ SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
+ },
+ dtCfg: minDingTalkURLs(),
+ want: false,
+ },
+ {
+ name: "dingtalk enabled + bypass=true but policy=none → false",
+ signupSource: "dingtalk",
+ settings: map[string]string{
+ SettingKeyDingTalkConnectEnabled: "true",
+ SettingKeyDingTalkConnectBypassRegistration: "true",
+ SettingKeyDingTalkConnectCorpRestrictionPolicy: "none",
+ },
+ dtCfg: minDingTalkURLs(),
+ want: false,
+ },
+ {
+ name: "dingtalk enabled + bypass=true + policy=internal_only → true",
+ signupSource: "dingtalk",
+ settings: map[string]string{
+ SettingKeyDingTalkConnectEnabled: "true",
+ SettingKeyDingTalkConnectBypassRegistration: "true",
+ SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only",
+ },
+ dtCfg: minDingTalkURLs(),
+ want: true,
+ },
+ }
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ svc := newAuthServiceWithDingTalkCfg(tc.settings, tc.dtCfg)
+ got := svc.canBypassRegistrationDisabledForOAuth(context.Background(), tc.signupSource)
+ require.Equal(t, tc.want, got)
+ })
+ }
+}
diff --git a/backend/internal/service/auth_service_test.go b/backend/internal/service/auth_service_test.go
new file mode 100644
index 00000000..2aeb6205
--- /dev/null
+++ b/backend/internal/service/auth_service_test.go
@@ -0,0 +1,13 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsReservedEmail_DingTalkDomain(t *testing.T) {
+ require.True(t, isReservedEmail("dingtalk-123@dingtalk-connect.invalid"))
+ require.True(t, isReservedEmail("DINGTALK-456@DINGTALK-CONNECT.INVALID")) // case-insensitive
+ require.False(t, isReservedEmail("real@dingtalk.com"))
+}
diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go
index 45025fe6..47975c8c 100644
--- a/backend/internal/service/billing_service.go
+++ b/backend/internal/service/billing_service.go
@@ -809,6 +809,7 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag
if imageCount <= 0 {
return &CostBreakdown{}
}
+ imageSize = NormalizeImageBillingTierOrDefault(imageSize)
// 获取单价
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go
index 8d3ca987..0232a258 100644
--- a/backend/internal/service/billing_service_image_test.go
+++ b/backend/internal/service/billing_service_image_test.go
@@ -48,6 +48,21 @@ func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
}
+func TestCalculateImageCost_NormalizesInvalidSizeTo2K(t *testing.T) {
+ svc := &BillingService{}
+
+ price2K := 0.25
+ groupConfig := &ImagePriceConfig{Price2K: &price2K}
+
+ for _, imageSize := range []string{"", "auto", "not-a-size"} {
+ t.Run(imageSize, func(t *testing.T) {
+ cost := svc.CalculateImageCost("gemini-3-pro-image", imageSize, 2, groupConfig, 1.0)
+ require.InDelta(t, 0.50, cost.TotalCost, 0.0001)
+ require.InDelta(t, 0.50, cost.ActualCost, 0.0001)
+ })
+ }
+}
+
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
svc := &BillingService{}
diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go
index 158bf8a3..760f688d 100644
--- a/backend/internal/service/channel.go
+++ b/backend/internal/service/channel.go
@@ -262,10 +262,17 @@ func deepCopyFeaturesConfig(src map[string]any) map[string]any {
}
// ValidateIntervals 校验区间列表的合法性。
-// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
-// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
-// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。
-func ValidateIntervals(intervals []PricingInterval) error {
+//
+// mode 决定区间语义:
+// - BillingModeToken(含空值):区间是上下文 token 数分段 (min, max],
+// 按 MinTokens 排序后无重叠,无界区间(MaxTokens=nil)必须是最后一个。
+// - BillingModePerRequest / BillingModeImage:区间是按 tier_label
+// (1K/2K/4K 等) 分层,匹配走 label 不依赖 min/max,因此跳过区间重叠
+// 与 last-unlimited 校验,仅做单条字段自洽(min/max/价格非负)检查。
+//
+// 通用规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
+// 所有价格字段 >= 0。
+func ValidateIntervals(intervals []PricingInterval, mode BillingMode) error {
if len(intervals) == 0 {
return nil
}
@@ -280,6 +287,11 @@ func ValidateIntervals(intervals []PricingInterval) error {
return err
}
}
+
+ // per_request / image 模式按 tier_label 匹配,不做 token 区间重叠校验
+ if mode == BillingModePerRequest || mode == BillingModeImage {
+ return nil
+ }
return validateIntervalOverlap(sorted)
}
diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go
index 815730e3..d2d24659 100644
--- a/backend/internal/service/channel_available.go
+++ b/backend/internal/service/channel_available.go
@@ -103,7 +103,11 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel,
}
// fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份
-// 展示用定价(按 token 计费)。仅用于「可用渠道」展示,不影响真实计费链路。
+// 展示用定价。仅用于「可用渠道」展示,不影响真实计费链路。
+//
+// 触发条件:
+// 1. Pricing == nil(渠道完全没声明该模型的定价条目)
+// 2. Pricing 非 nil 但所有价格字段为空(admin UI 建了条目但没填价格)
//
// 当 s.pricingService 为 nil(测试场景),跳过回落。
func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) {
@@ -111,28 +115,72 @@ func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) {
return
}
for i := range models {
- if models[i].Pricing != nil {
+ if !pricingNeedsFallback(models[i].Pricing) {
continue
}
lp := s.pricingService.GetModelPricing(models[i].Name)
if lp == nil {
continue
}
- models[i].Pricing = synthesizePricingFromLiteLLM(lp)
+ models[i].Pricing = synthesizePricingFromLiteLLM(lp, models[i].Pricing)
}
}
+// pricingNeedsFallback 判定一个 ChannelModelPricing 是否需要走全局回落。
+// 价格全部缺失(无 flat 字段且无任何带价 interval)即视为未配置。
+func pricingNeedsFallback(p *ChannelModelPricing) bool {
+ if p == nil {
+ return true
+ }
+ if p.InputPrice != nil || p.OutputPrice != nil ||
+ p.CacheWritePrice != nil || p.CacheReadPrice != nil ||
+ p.ImageOutputPrice != nil || p.PerRequestPrice != nil {
+ return false
+ }
+ for _, iv := range p.Intervals {
+ if iv.InputPrice != nil || iv.OutputPrice != nil ||
+ iv.CacheWritePrice != nil || iv.CacheReadPrice != nil ||
+ iv.PerRequestPrice != nil {
+ return false
+ }
+ }
+ return true
+}
+
// synthesizePricingFromLiteLLM 把 LiteLLM 的定价数据转成 ChannelModelPricing 形态,
-// 仅用于展示。BillingMode 固定为 token;图片场景的 OutputCostPerImageToken 也归到
-// ImageOutputPrice 字段(与渠道侧"图片输出按 token 计价"语义一致)。
+// 仅用于展示。
+//
+// 计费模式优先级:
+// 1. 渠道已选 BillingMode(admin 在 UI 里选了 image / per_request 但没填价的场景,
+// 按选定模式合成对应字段)
+// 2. LiteLLM mode="image_generation" → image
+// 3. 默认 token
//
// LiteLLM 中字段 0 视为未配置,不带入展示。
-func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing) *ChannelModelPricing {
+func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelModelPricing) *ChannelModelPricing {
if lp == nil {
- return nil
+ return existing
+ }
+
+ mode := BillingModeToken
+ switch {
+ case existing != nil && existing.BillingMode != "":
+ mode = existing.BillingMode
+ case lp.Mode == "image_generation":
+ mode = BillingModeImage
+ }
+
+ if mode == BillingModeImage || mode == BillingModePerRequest {
+ return &ChannelModelPricing{
+ BillingMode: mode,
+ PerRequestPrice: nonZeroPtr(lp.OutputCostPerImage),
+ ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken),
+ InputPrice: nonZeroPtr(lp.InputCostPerToken),
+ OutputPrice: nonZeroPtr(lp.OutputCostPerToken),
+ }
}
return &ChannelModelPricing{
- BillingMode: BillingModeToken,
+ BillingMode: mode,
InputPrice: nonZeroPtr(lp.InputCostPerToken),
OutputPrice: nonZeroPtr(lp.OutputCostPerToken),
CacheWritePrice: nonZeroPtr(lp.CacheCreationInputTokenCost),
diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go
index 8be70ceb..d59e587e 100644
--- a/backend/internal/service/channel_available_test.go
+++ b/backend/internal/service/channel_available_test.go
@@ -175,3 +175,137 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) {
require.Equal(t, BillingModelSourceChannelMapped, byName["empty"])
require.Equal(t, BillingModelSourceUpstream, byName["explicit"])
}
+
+func TestPricingNeedsFallback(t *testing.T) {
+ tests := []struct {
+ name string
+ in *ChannelModelPricing
+ want bool
+ }{
+ {"nil", nil, true},
+ {"empty struct", &ChannelModelPricing{BillingMode: BillingModeToken}, true},
+ {"all-empty intervals", &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ Intervals: []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}},
+ }, true},
+ {"flat input set", &ChannelModelPricing{InputPrice: testPtrFloat64(3e-6)}, false},
+ {"flat per_request set", &ChannelModelPricing{PerRequestPrice: testPtrFloat64(0.04)}, false},
+ {"interval with price", &ChannelModelPricing{
+ Intervals: []PricingInterval{{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}},
+ }, false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.want, pricingNeedsFallback(tt.in))
+ })
+ }
+}
+
+func TestSynthesizePricingFromLiteLLM_TokenMode(t *testing.T) {
+ lp := &LiteLLMModelPricing{
+ Mode: "chat",
+ InputCostPerToken: 3e-6,
+ OutputCostPerToken: 1.5e-5,
+ CacheCreationInputTokenCost: 3.75e-6,
+ CacheReadInputTokenCost: 3e-7,
+ }
+ got := synthesizePricingFromLiteLLM(lp, nil)
+ require.NotNil(t, got)
+ require.Equal(t, BillingModeToken, got.BillingMode)
+ require.NotNil(t, got.InputPrice)
+ require.InDelta(t, 3e-6, *got.InputPrice, 1e-12)
+ require.NotNil(t, got.CacheReadPrice)
+}
+
+func TestSynthesizePricingFromLiteLLM_ImageGenerationMode(t *testing.T) {
+ // LiteLLM mode=image_generation 且渠道未声明模式时,按 image 合成。
+ lp := &LiteLLMModelPricing{
+ Mode: "image_generation",
+ OutputCostPerImageToken: 4e-5,
+ }
+ got := synthesizePricingFromLiteLLM(lp, nil)
+ require.NotNil(t, got)
+ require.Equal(t, BillingModeImage, got.BillingMode)
+ require.Nil(t, got.PerRequestPrice)
+ require.NotNil(t, got.ImageOutputPrice)
+}
+
+func TestSynthesizePricingFromLiteLLM_RespectsExistingChannelMode(t *testing.T) {
+ // admin UI 选了 per_request 但没填价:LiteLLM 数据按 per_request 合成,
+ // 即便 LiteLLM 标的是 chat 模式也尊重渠道选择。
+ lp := &LiteLLMModelPricing{
+ Mode: "chat",
+ InputCostPerToken: 5e-6,
+ OutputCostPerImage: 0.04,
+ }
+ existing := &ChannelModelPricing{BillingMode: BillingModePerRequest}
+ got := synthesizePricingFromLiteLLM(lp, existing)
+ require.NotNil(t, got)
+ require.Equal(t, BillingModePerRequest, got.BillingMode)
+ require.NotNil(t, got.PerRequestPrice)
+ require.InDelta(t, 0.04, *got.PerRequestPrice, 1e-12)
+}
+
+func TestFillGlobalPricingFallback_NilPricing(t *testing.T) {
+ pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
+ "claude-opus-4-5": {Mode: "chat", InputCostPerToken: 5e-6},
+ })
+ svc := &ChannelService{pricingService: pricingSvc}
+
+ models := []SupportedModel{
+ {Name: "claude-opus-4-5", Platform: "anthropic"},
+ }
+ svc.fillGlobalPricingFallback(models)
+ require.NotNil(t, models[0].Pricing)
+ require.NotNil(t, models[0].Pricing.InputPrice)
+ require.InDelta(t, 5e-6, *models[0].Pricing.InputPrice, 1e-12)
+}
+
+func TestFillGlobalPricingFallback_EmptyPricingFillsFromLiteLLM(t *testing.T) {
+ // 核心场景:admin UI 建了 pricing 条目(image 模式)但没填价,应走 LiteLLM 兜底。
+ pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
+ "gpt-image-1": {
+ Mode: "image_generation",
+ OutputCostPerImageToken: 4e-5,
+ },
+ })
+ svc := &ChannelService{pricingService: pricingSvc}
+
+ models := []SupportedModel{
+ {
+ Name: "gpt-image-1",
+ Platform: "openai",
+ Pricing: &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ Intervals: []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}},
+ },
+ },
+ }
+ svc.fillGlobalPricingFallback(models)
+ require.NotNil(t, models[0].Pricing)
+ require.Equal(t, BillingModeImage, models[0].Pricing.BillingMode)
+ require.NotNil(t, models[0].Pricing.ImageOutputPrice)
+ require.InDelta(t, 4e-5, *models[0].Pricing.ImageOutputPrice, 1e-12)
+}
+
+func TestFillGlobalPricingFallback_KeepsExistingPrice(t *testing.T) {
+ // 渠道已经填了价格的条目不应被回落覆盖。
+ pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{
+ "served-model": {Mode: "chat", InputCostPerToken: 1e-6},
+ })
+ svc := &ChannelService{pricingService: pricingSvc}
+
+ existing := &ChannelModelPricing{
+ BillingMode: BillingModeToken,
+ InputPrice: testPtrFloat64(9e-9),
+ }
+ models := []SupportedModel{
+ {Name: "served-model", Platform: "anthropic", Pricing: existing},
+ }
+ svc.fillGlobalPricingFallback(models)
+ require.Same(t, existing, models[0].Pricing)
+}
+
+func newStubPricingServiceFromMap(data map[string]*LiteLLMModelPricing) *PricingService {
+ return &PricingService{pricingData: data}
+}
diff --git a/backend/internal/service/channel_monitor_checker.go b/backend/internal/service/channel_monitor_checker.go
index 33570629..25737e45 100644
--- a/backend/internal/service/channel_monitor_checker.go
+++ b/backend/internal/service/channel_monitor_checker.go
@@ -40,6 +40,8 @@ func newSSRFSafeHTTPClient(timeout time.Duration) *http.Client {
// CheckOptions 承载一次检测的自定义入参。
// 所有字段都是可选(零值即等价于"用默认行为")。
type CheckOptions struct {
+ // APIMode 仅对 OpenAI provider 生效;空串等同 chat_completions。
+ APIMode string
// ExtraHeaders 用户自定义 HTTP 头(merge 到 adapter 默认 headers,用户优先)。
ExtraHeaders map[string]string
// BodyOverrideMode: off | merge | replace
@@ -164,21 +166,7 @@ type providerAdapter struct {
//
//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
var providerAdapters = map[string]providerAdapter{
- MonitorProviderOpenAI: {
- buildPath: func(string) string { return providerOpenAIPath },
- buildBody: func(model, prompt string) ([]byte, error) {
- return json.Marshal(map[string]any{
- "model": model,
- "messages": []map[string]string{{"role": "user", "content": prompt}},
- "max_tokens": monitorChallengeMaxTokens,
- "stream": false,
- })
- },
- buildHeaders: func(apiKey string) map[string]string {
- return map[string]string{"Authorization": "Bearer " + apiKey}
- },
- textPath: "choices.0.message.content",
- },
+ MonitorProviderOpenAI: providerOpenAIChatAdapter,
MonitorProviderAnthropic: {
buildPath: func(string) string { return providerAnthropicPath },
buildBody: func(model, prompt string) ([]byte, error) {
@@ -215,6 +203,50 @@ var providerAdapters = map[string]providerAdapter{
},
}
+//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
+var providerOpenAIChatAdapter = providerAdapter{
+ buildPath: func(string) string { return providerOpenAIPath },
+ buildBody: func(model, prompt string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "model": model,
+ "messages": []map[string]string{{"role": "user", "content": prompt}},
+ "max_tokens": monitorChallengeMaxTokens,
+ "stream": false,
+ })
+ },
+ buildHeaders: func(apiKey string) map[string]string {
+ return map[string]string{"Authorization": "Bearer " + apiKey}
+ },
+ textPath: "choices.0.message.content",
+}
+
+//nolint:gochecknoglobals // 适配器表是只读静态数据,初始化后不变更。
+var providerOpenAIResponsesAdapter = providerAdapter{
+ buildPath: func(string) string { return providerOpenAIResponsesPath },
+ buildBody: func(model, prompt string) ([]byte, error) {
+ return json.Marshal(map[string]any{
+ "model": model,
+ "instructions": "You are a channel health-check endpoint. Answer the arithmetic challenge exactly and briefly.",
+ "input": prompt,
+ "max_output_tokens": monitorChallengeMaxTokens,
+ "stream": false,
+ })
+ },
+ buildHeaders: func(apiKey string) map[string]string {
+ return map[string]string{"Authorization": "Bearer " + apiKey}
+ },
+ textPath: "output.0.content.0.text",
+}
+
+// providerAdapterFor 按 provider + api_mode 选择具体 adapter。
+func providerAdapterFor(provider, apiMode string) (providerAdapter, string, bool) {
+ if provider == MonitorProviderOpenAI && defaultAPIMode(apiMode) == MonitorAPIModeResponses {
+ return providerOpenAIResponsesAdapter, MonitorAPIModeResponses, true
+ }
+ adapter, ok := providerAdapters[provider]
+ return adapter, MonitorAPIModeChatCompletions, ok
+}
+
// isSupportedProvider 校验 provider 字符串是否在 adapter 表中。
// 供 validate.go 的 validateProvider 复用,避免两份 switch 漂移。
func isSupportedProvider(p string) bool {
@@ -231,11 +263,15 @@ func isSupportedProvider(p string) bool {
// - status: HTTP 状态码
// - err: 网络 / 序列化错误
func callProvider(ctx context.Context, provider, endpoint, apiKey, model, prompt string, opts *CheckOptions) (extractedText, rawBody string, status int, err error) {
- adapter, ok := providerAdapters[provider]
+ requestedAPIMode := checkAPIMode(opts)
+ if err := validateAPIMode(provider, requestedAPIMode); err != nil {
+ return "", "", 0, err
+ }
+ adapter, apiMode, ok := providerAdapterFor(provider, requestedAPIMode)
if !ok {
return "", "", 0, fmt.Errorf("unsupported provider %q", provider)
}
- body, err := buildRequestBody(adapter, provider, model, prompt, opts)
+ body, err := buildRequestBody(adapter, provider, apiMode, model, prompt, opts)
if err != nil {
return "", "", 0, err
}
@@ -275,13 +311,16 @@ func mergeHeaders(base map[string]string, opts *CheckOptions) map[string]string
// - replace: 直接 marshal BodyOverride 作为完整 body
//
// 任何 mode 返回的 []byte 都已经是合法 JSON,可直接送入 postRawJSON。
-func buildRequestBody(adapter providerAdapter, provider, model, prompt string, opts *CheckOptions) ([]byte, error) {
+func buildRequestBody(adapter providerAdapter, provider, apiMode, model, prompt string, opts *CheckOptions) ([]byte, error) {
mode := bodyOverrideMode(opts)
if mode == MonitorBodyOverrideModeReplace {
if opts == nil || len(opts.BodyOverride) == 0 {
return nil, fmt.Errorf("replace mode: body_override is empty")
}
+ if err := validateReplaceRequestBody(provider, apiMode, opts.BodyOverride); err != nil {
+ return nil, err
+ }
body, err := json.Marshal(opts.BodyOverride)
if err != nil {
return nil, fmt.Errorf("marshal body_override (replace): %w", err)
@@ -301,7 +340,7 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
if err := json.Unmarshal(defaultBody, &defaultMap); err != nil {
return nil, fmt.Errorf("unmarshal default body for merge: %w", err)
}
- deny := bodyMergeKeyDenyList[provider]
+ deny := bodyMergeKeyDenyList[bodyMergeDenyKey(provider, apiMode)]
for k, v := range opts.BodyOverride {
if deny[k] {
continue
@@ -321,9 +360,63 @@ func buildRequestBody(adapter providerAdapter, provider, model, prompt string, o
//
//nolint:gochecknoglobals // 静态查表,初始化后不变。
var bodyMergeKeyDenyList = map[string]map[string]bool{
- MonitorProviderOpenAI: {"model": true, "messages": true, "stream": true},
- MonitorProviderAnthropic: {"model": true, "messages": true},
- MonitorProviderGemini: {"contents": true},
+ MonitorProviderOpenAI + ":" + MonitorAPIModeChatCompletions: {"model": true, "messages": true, "stream": true},
+ MonitorProviderOpenAI + ":" + MonitorAPIModeResponses: {"model": true, "instructions": true, "input": true, "stream": true},
+ MonitorProviderAnthropic: {"model": true, "messages": true},
+ MonitorProviderGemini: {"contents": true},
+}
+
+func checkAPIMode(opts *CheckOptions) string {
+ if opts == nil {
+ return MonitorAPIModeChatCompletions
+ }
+ return defaultAPIMode(opts.APIMode)
+}
+
+func bodyMergeDenyKey(provider, apiMode string) string {
+ if provider == MonitorProviderOpenAI {
+ return provider + ":" + defaultAPIMode(apiMode)
+ }
+ return provider
+}
+
+func validateReplaceRequestBody(provider, apiMode string, body map[string]any) error {
+ if provider != MonitorProviderOpenAI {
+ return nil
+ }
+ switch defaultAPIMode(apiMode) {
+ case MonitorAPIModeResponses:
+ if strings.TrimSpace(stringFromAny(body["instructions"])) == "" || !hasNonEmptyBodyValue(body["input"]) {
+ return fmt.Errorf("replace mode responses body: instructions and input are required")
+ }
+ case MonitorAPIModeChatCompletions:
+ if !hasNonEmptyBodyValue(body["messages"]) {
+ return fmt.Errorf("replace mode chat_completions body: messages are required")
+ }
+ }
+ return nil
+}
+
+func stringFromAny(v any) string {
+ s, _ := v.(string)
+ return s
+}
+
+func hasNonEmptyBodyValue(v any) bool {
+ switch val := v.(type) {
+ case nil:
+ return false
+ case string:
+ return strings.TrimSpace(val) != ""
+ case []any:
+ return len(val) > 0
+ case []map[string]any:
+ return len(val) > 0
+ case []map[string]string:
+ return len(val) > 0
+ default:
+ return true
+ }
}
// postRawJSON 发送 POST + 已序列化好的 JSON 字节,限制响应体大小,返回响应字节、HTTP status、错误。
diff --git a/backend/internal/service/channel_monitor_checker_body_test.go b/backend/internal/service/channel_monitor_checker_body_test.go
index 323cf8b7..620cf565 100644
--- a/backend/internal/service/channel_monitor_checker_body_test.go
+++ b/backend/internal/service/channel_monitor_checker_body_test.go
@@ -7,6 +7,8 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
+ "regexp"
+ "strconv"
"strings"
"testing"
"time"
@@ -57,6 +59,76 @@ func setupFakeAnthropic(t *testing.T, handler *captureHandler) string {
return srv.URL
}
+type openAICaptureHandler struct {
+ lastBody map[string]any
+ lastHeaders http.Header
+ lastPath string
+ status int
+}
+
+func (h *openAICaptureHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ h.lastHeaders = r.Header.Clone()
+ h.lastPath = r.URL.Path
+ defer func() { _ = r.Body.Close() }()
+ var parsed map[string]any
+ _ = json.NewDecoder(r.Body).Decode(&parsed)
+ h.lastBody = parsed
+
+ if h.status == 0 {
+ h.status = http.StatusOK
+ }
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(h.status)
+
+ answer := answerFromOpenAIRequest(parsed)
+ if h.lastPath == providerOpenAIResponsesPath {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "output": []map[string]any{{
+ "content": []map[string]any{{"type": "output_text", "text": answer}},
+ }},
+ })
+ return
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{{"message": map[string]any{"content": answer}}},
+ })
+}
+
+func setupFakeOpenAI(t *testing.T, handler *openAICaptureHandler) string {
+ t.Helper()
+ swapMonitorHTTPClient(t)
+ srv := httptest.NewServer(handler)
+ t.Cleanup(srv.Close)
+ return srv.URL
+}
+
+func answerFromOpenAIRequest(body map[string]any) string {
+ prompt, _ := body["input"].(string)
+ if prompt == "" {
+ if messages, ok := body["messages"].([]any); ok && len(messages) > 0 {
+ if msg, ok := messages[0].(map[string]any); ok {
+ prompt, _ = msg["content"].(string)
+ }
+ }
+ }
+ return answerFromChallengePrompt(prompt)
+}
+
+var challengeQuestionRegex = regexp.MustCompile(`Q: (\d+) ([+-]) (\d+) = \?\nA:$`)
+
+func answerFromChallengePrompt(prompt string) string {
+ m := challengeQuestionRegex.FindStringSubmatch(prompt)
+ if len(m) != 4 {
+ return "0"
+ }
+ left, _ := strconv.Atoi(m[1])
+ right, _ := strconv.Atoi(m[3])
+ if m[2] == "+" {
+ return strconv.Itoa(left + right)
+ }
+ return strconv.Itoa(left - right)
+}
+
func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
h := &captureHandler{respondText: "the answer is 42"}
endpoint := setupFakeAnthropic(t, h)
@@ -75,6 +147,95 @@ func TestRunCheckForModel_OffMode_PreservesDefaultBody(t *testing.T) {
}
}
+func TestRunCheckForModel_OpenAI_DefaultChatRequest(t *testing.T) {
+ h := &openAICaptureHandler{}
+ endpoint := setupFakeOpenAI(t, h)
+
+ res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", nil)
+
+ if res.Status != MonitorStatusOperational {
+ t.Fatalf("default chat request should pass challenge, got status=%s message=%q", res.Status, res.Message)
+ }
+ if h.lastPath != providerOpenAIPath {
+ t.Fatalf("expected chat completions path %q, got %q", providerOpenAIPath, h.lastPath)
+ }
+ if h.lastBody["model"] != "gpt-test" {
+ t.Errorf("chat body should contain model=gpt-test, got %v", h.lastBody["model"])
+ }
+ if _, ok := h.lastBody["messages"]; !ok {
+ t.Error("chat body should contain messages")
+ }
+ if _, ok := h.lastBody["instructions"]; ok {
+ t.Error("chat body must not contain top-level instructions")
+ }
+ if h.lastBody["stream"] != false {
+ t.Errorf("chat body should set stream=false, got %v", h.lastBody["stream"])
+ }
+ if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
+ t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
+ }
+}
+
+func TestRunCheckForModel_OpenAIResponses_DefaultRequest(t *testing.T) {
+ h := &openAICaptureHandler{}
+ endpoint := setupFakeOpenAI(t, h)
+
+ res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
+ APIMode: MonitorAPIModeResponses,
+ })
+
+ if res.Status != MonitorStatusOperational {
+ t.Fatalf("default responses request should pass challenge, got status=%s message=%q", res.Status, res.Message)
+ }
+ if h.lastPath != providerOpenAIResponsesPath {
+ t.Fatalf("expected responses path %q, got %q", providerOpenAIResponsesPath, h.lastPath)
+ }
+ if h.lastBody["model"] != "gpt-test" {
+ t.Errorf("responses body should contain model=gpt-test, got %v", h.lastBody["model"])
+ }
+ instructions, _ := h.lastBody["instructions"].(string)
+ if strings.TrimSpace(instructions) == "" {
+ t.Error("responses body should contain non-empty instructions")
+ }
+ input, _ := h.lastBody["input"].(string)
+ if strings.TrimSpace(input) == "" {
+ t.Error("responses body should contain non-empty input")
+ }
+ if _, ok := h.lastBody["messages"]; ok {
+ t.Error("responses body must not contain chat messages")
+ }
+ if h.lastBody["stream"] != false {
+ t.Errorf("responses body should set stream=false, got %v", h.lastBody["stream"])
+ }
+ if h.lastHeaders.Get("Authorization") != "Bearer sk-openai" {
+ t.Errorf("expected bearer auth header, got %q", h.lastHeaders.Get("Authorization"))
+ }
+}
+
+func TestRunCheckForModel_OpenAIResponsesReplaceMissingInstructionsFailsLocally(t *testing.T) {
+ h := &openAICaptureHandler{}
+ endpoint := setupFakeOpenAI(t, h)
+
+ res := runCheckForModel(context.Background(), MonitorProviderOpenAI, endpoint, "sk-openai", "gpt-test", &CheckOptions{
+ APIMode: MonitorAPIModeResponses,
+ BodyOverrideMode: MonitorBodyOverrideModeReplace,
+ BodyOverride: map[string]any{
+ "model": "gpt-test",
+ "input": "hello",
+ },
+ })
+
+ if res.Status != MonitorStatusError {
+ t.Fatalf("invalid responses replace body should fail locally as error, got status=%s", res.Status)
+ }
+ if !strings.Contains(res.Message, "instructions and input are required") {
+ t.Errorf("expected local validation message about instructions/input, got %q", res.Message)
+ }
+ if h.lastPath != "" {
+ t.Errorf("invalid replace body should fail before HTTP request, got path %q", h.lastPath)
+ }
+}
+
func TestRunCheckForModel_MergeMode_UserFieldsWinButDenyListProtects(t *testing.T) {
h := &captureHandler{respondText: "the answer is 42"}
endpoint := setupFakeAnthropic(t, h)
diff --git a/backend/internal/service/channel_monitor_const.go b/backend/internal/service/channel_monitor_const.go
index 2e1614f7..d1dffac5 100644
--- a/backend/internal/service/channel_monitor_const.go
+++ b/backend/internal/service/channel_monitor_const.go
@@ -47,6 +47,8 @@ const (
// providerOpenAIPath OpenAI Chat Completions 路径。
providerOpenAIPath = "/v1/chat/completions"
+ // providerOpenAIResponsesPath OpenAI Responses API 路径。
+ providerOpenAIResponsesPath = "/v1/responses"
// providerAnthropicPath Anthropic Messages 路径。
providerAnthropicPath = "/v1/messages"
// providerGeminiPathTemplate Gemini generateContent 路径模板(含 model 占位)。
@@ -112,6 +114,12 @@ var (
ErrChannelMonitorInvalidProvider = infraerrors.BadRequest(
"CHANNEL_MONITOR_INVALID_PROVIDER", "provider must be one of openai/anthropic/gemini",
)
+ ErrChannelMonitorInvalidAPIMode = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_INVALID_API_MODE", "api_mode must be chat_completions or responses; responses is only supported for openai",
+ )
+ ErrChannelMonitorInvalidRequestBody = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_INVALID_REQUEST_BODY", "openai replace-mode body_override must include non-empty messages for chat_completions or non-empty instructions and input for responses",
+ )
ErrChannelMonitorInvalidInterval = infraerrors.BadRequest(
"CHANNEL_MONITOR_INVALID_INTERVAL", "interval_seconds must be in [15, 3600]",
)
diff --git a/backend/internal/service/channel_monitor_service.go b/backend/internal/service/channel_monitor_service.go
index 7050e141..6eec0ae0 100644
--- a/backend/internal/service/channel_monitor_service.go
+++ b/backend/internal/service/channel_monitor_service.go
@@ -107,7 +107,7 @@ func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCrea
if err := validateCreateParams(p); err != nil {
return nil, err
}
- if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ if err := validateBodyModeForProtocol(p.Provider, p.APIMode, p.BodyOverrideMode, p.BodyOverride); err != nil {
return nil, err
}
if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
@@ -120,6 +120,7 @@ func (s *ChannelMonitorService) Create(ctx context.Context, p ChannelMonitorCrea
m := &ChannelMonitor{
Name: strings.TrimSpace(p.Name),
Provider: p.Provider,
+ APIMode: defaultAPIMode(p.APIMode),
Endpoint: normalizeEndpoint(p.Endpoint),
APIKey: encrypted, // 注意:传入 repository 时该字段为密文
PrimaryModel: strings.TrimSpace(p.PrimaryModel),
@@ -150,6 +151,9 @@ func validateCreateParams(p ChannelMonitorCreateParams) error {
if err := validateProvider(p.Provider); err != nil {
return err
}
+ if err := validateAPIMode(p.Provider, p.APIMode); err != nil {
+ return err
+ }
if err := validateInterval(p.IntervalSeconds); err != nil {
return err
}
@@ -298,6 +302,7 @@ func (s *ChannelMonitorService) runChecksConcurrent(ctx context.Context, m *Chan
// 所有模型共用同一份 CheckOptions(来自监控的快照字段)。
opts := &CheckOptions{
+ APIMode: m.APIMode,
ExtraHeaders: m.ExtraHeaders,
BodyOverrideMode: m.BodyOverrideMode,
BodyOverride: m.BodyOverride,
@@ -469,6 +474,7 @@ func (s *ChannelMonitorService) decryptInPlace(m *ChannelMonitor) {
// 行数稍超过 30:这是逐字段平铺的 dispatcher,每个 if 都是 1-3 行的"非 nil 则覆盖"模式,
// 拆分反而会增加跳转噪音、影响可读性,故保留为单函数。
func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
+ providerChanged := false
if p.Name != nil {
existing.Name = strings.TrimSpace(*p.Name)
}
@@ -477,6 +483,7 @@ func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams)
return err
}
existing.Provider = *p.Provider
+ providerChanged = true
}
if p.Endpoint != nil {
if err := validateEndpoint(*p.Endpoint); err != nil {
@@ -502,11 +509,11 @@ func applyMonitorUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams)
}
existing.IntervalSeconds = *p.IntervalSeconds
}
- return applyMonitorAdvancedUpdate(existing, p)
+ return applyMonitorAdvancedUpdate(existing, p, providerChanged)
}
// applyMonitorAdvancedUpdate 处理自定义请求快照相关字段,从 applyMonitorUpdate 拆出避免过长。
-func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams) error {
+func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdateParams, providerChanged bool) error {
if p.ClearTemplate {
existing.TemplateID = nil
} else if p.TemplateID != nil {
@@ -519,6 +526,15 @@ func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdate
}
existing.ExtraHeaders = emptyHeadersIfNil(*p.ExtraHeaders)
}
+ newAPIMode := defaultAPIMode(existing.APIMode)
+ if p.APIMode != nil {
+ newAPIMode = defaultAPIMode(*p.APIMode)
+ } else if existing.Provider != MonitorProviderOpenAI {
+ newAPIMode = MonitorAPIModeChatCompletions
+ }
+ if err := validateAPIMode(existing.Provider, newAPIMode); err != nil {
+ return err
+ }
// BodyOverrideMode / BodyOverride 联合校验,和模板一致。
newMode := existing.BodyOverrideMode
newBody := existing.BodyOverride
@@ -528,12 +544,13 @@ func applyMonitorAdvancedUpdate(existing *ChannelMonitor, p ChannelMonitorUpdate
if p.BodyOverride != nil {
newBody = *p.BodyOverride
}
- if p.BodyOverrideMode != nil || p.BodyOverride != nil {
- if err := validateBodyModeParams(newMode, newBody); err != nil {
+ if providerChanged || p.APIMode != nil || p.BodyOverrideMode != nil || p.BodyOverride != nil {
+ if err := validateBodyModeForProtocol(existing.Provider, newAPIMode, newMode, newBody); err != nil {
return err
}
existing.BodyOverrideMode = defaultBodyMode(newMode)
existing.BodyOverride = newBody
}
+ existing.APIMode = newAPIMode
return nil
}
diff --git a/backend/internal/service/channel_monitor_template_service.go b/backend/internal/service/channel_monitor_template_service.go
index 8d2e8173..4d50952d 100644
--- a/backend/internal/service/channel_monitor_template_service.go
+++ b/backend/internal/service/channel_monitor_template_service.go
@@ -14,14 +14,14 @@ type ChannelMonitorRequestTemplateRepository interface {
Update(ctx context.Context, t *ChannelMonitorRequestTemplate) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params ChannelMonitorRequestTemplateListParams) ([]*ChannelMonitorRequestTemplate, error)
- // ApplyToMonitors 把模板当前的 extra_headers / body_override_mode / body_override
+ // ApplyToMonitors 把模板当前的 api_mode / extra_headers / body_override_mode / body_override
// 批量覆盖到指定 monitorIDs 的监控上(同时还要求这些监控当前 template_id = id,
// 防止误覆盖未关联的监控)。monitorIDs 必须非空;空列表直接返回 0 不写库。
// 返回被覆盖的监控数量。
ApplyToMonitors(ctx context.Context, id int64, monitorIDs []int64) (int64, error)
// CountAssociatedMonitors 统计 template_id = id 的监控数(用于 UI 展示「应用到 N 个配置」)。
CountAssociatedMonitors(ctx context.Context, id int64) (int64, error)
- // ListAssociatedMonitors 列出所有 template_id = id 的监控简略信息(id/name/provider/enabled)
+ // ListAssociatedMonitors 列出所有 template_id = id 的监控简略信息(id/name/provider/api_mode/enabled)
// 给 apply picker UI 用,避免前端再做一次 list+filter。
ListAssociatedMonitors(ctx context.Context, id int64) ([]*AssociatedMonitorBrief, error)
}
@@ -31,6 +31,7 @@ type AssociatedMonitorBrief struct {
ID int64
Name string
Provider string
+ APIMode string
Enabled bool
}
@@ -53,6 +54,15 @@ func (s *ChannelMonitorRequestTemplateService) List(ctx context.Context, params
return nil, err
}
}
+ if params.APIMode != "" {
+ if params.Provider == "" {
+ if err := validateAPIMode(MonitorProviderOpenAI, params.APIMode); err != nil {
+ return nil, err
+ }
+ } else if err := validateAPIMode(params.Provider, params.APIMode); err != nil {
+ return nil, err
+ }
+ }
return s.repo.List(ctx, params)
}
@@ -69,6 +79,7 @@ func (s *ChannelMonitorRequestTemplateService) Create(ctx context.Context, p Cha
t := &ChannelMonitorRequestTemplate{
Name: strings.TrimSpace(p.Name),
Provider: p.Provider,
+ APIMode: defaultAPIMode(p.APIMode),
Description: strings.TrimSpace(p.Description),
ExtraHeaders: emptyHeadersIfNil(p.ExtraHeaders),
BodyOverrideMode: defaultBodyMode(p.BodyOverrideMode),
@@ -144,7 +155,10 @@ func validateTemplateCreateParams(p ChannelMonitorRequestTemplateCreateParams) e
if err := validateProvider(p.Provider); err != nil {
return ErrChannelMonitorTemplateInvalidProvider
}
- if err := validateBodyModeParams(p.BodyOverrideMode, p.BodyOverride); err != nil {
+ if err := validateAPIMode(p.Provider, p.APIMode); err != nil {
+ return ErrChannelMonitorTemplateInvalidAPIMode
+ }
+ if err := validateBodyModeForProtocol(p.Provider, p.APIMode, p.BodyOverrideMode, p.BodyOverride); err != nil {
return err
}
if err := validateExtraHeaders(p.ExtraHeaders); err != nil {
@@ -165,6 +179,13 @@ func applyTemplateUpdate(existing *ChannelMonitorRequestTemplate, p ChannelMonit
if p.Description != nil {
existing.Description = strings.TrimSpace(*p.Description)
}
+ newAPIMode := defaultAPIMode(existing.APIMode)
+ if p.APIMode != nil {
+ newAPIMode = defaultAPIMode(*p.APIMode)
+ }
+ if err := validateAPIMode(existing.Provider, newAPIMode); err != nil {
+ return ErrChannelMonitorTemplateInvalidAPIMode
+ }
if p.ExtraHeaders != nil {
if err := validateExtraHeaders(*p.ExtraHeaders); err != nil {
return err
@@ -180,14 +201,29 @@ func applyTemplateUpdate(existing *ChannelMonitorRequestTemplate, p ChannelMonit
if p.BodyOverride != nil {
newBody = *p.BodyOverride
}
- if err := validateBodyModeParams(newMode, newBody); err != nil {
+ if err := validateBodyModeForProtocol(existing.Provider, newAPIMode, newMode, newBody); err != nil {
return err
}
+ existing.APIMode = newAPIMode
existing.BodyOverrideMode = defaultBodyMode(newMode)
existing.BodyOverride = newBody
return nil
}
+// validateBodyModeForProtocol 校验 body_override_mode 与 provider/api_mode 的协议特定要求。
+func validateBodyModeForProtocol(provider, apiMode, mode string, body map[string]any) error {
+ if err := validateBodyModeParams(mode, body); err != nil {
+ return err
+ }
+ if defaultBodyMode(mode) != MonitorBodyOverrideModeReplace {
+ return nil
+ }
+ if err := validateReplaceRequestBody(provider, defaultAPIMode(apiMode), body); err != nil {
+ return ErrChannelMonitorInvalidRequestBody
+ }
+ return nil
+}
+
// validateBodyModeParams 校验 body_override_mode 合法,且 merge/replace 模式下 body_override 非空。
func validateBodyModeParams(mode string, body map[string]any) error {
switch mode {
diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go
index 06b4f3ab..f4795aac 100644
--- a/backend/internal/service/channel_monitor_template_types.go
+++ b/backend/internal/service/channel_monitor_template_types.go
@@ -13,6 +13,7 @@ type ChannelMonitorRequestTemplate struct {
ID int64
Name string
Provider string
+ APIMode string
Description string
ExtraHeaders map[string]string
BodyOverrideMode string
@@ -24,12 +25,14 @@ type ChannelMonitorRequestTemplate struct {
// ChannelMonitorRequestTemplateListParams 列表过滤。
type ChannelMonitorRequestTemplateListParams struct {
Provider string // 空 = 全部;非空则按 provider 过滤
+ APIMode string // 空 = 全部;非空则按 api_mode 过滤
}
// ChannelMonitorRequestTemplateCreateParams 创建参数。
type ChannelMonitorRequestTemplateCreateParams struct {
Name string
Provider string
+ APIMode string
Description string
ExtraHeaders map[string]string
BodyOverrideMode string
@@ -40,6 +43,7 @@ type ChannelMonitorRequestTemplateCreateParams struct {
// 注意 Provider 不可修改:改 provider 会让已关联监控的 body 黑名单语义错乱。
type ChannelMonitorRequestTemplateUpdateParams struct {
Name *string
+ APIMode *string
Description *string
ExtraHeaders *map[string]string
BodyOverrideMode *string
@@ -54,6 +58,9 @@ var (
ErrChannelMonitorTemplateInvalidProvider = infraerrors.BadRequest(
"CHANNEL_MONITOR_TEMPLATE_INVALID_PROVIDER", "template provider must be one of openai/anthropic/gemini",
)
+ ErrChannelMonitorTemplateInvalidAPIMode = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_INVALID_API_MODE", "template api_mode must be chat_completions or responses; responses is only supported for openai",
+ )
ErrChannelMonitorTemplateMissingName = infraerrors.BadRequest(
"CHANNEL_MONITOR_TEMPLATE_MISSING_NAME", "template name is required",
)
@@ -72,6 +79,9 @@ var (
ErrChannelMonitorTemplateProviderMismatch = infraerrors.BadRequest(
"CHANNEL_MONITOR_TEMPLATE_PROVIDER_MISMATCH", "monitor provider does not match template provider",
)
+ ErrChannelMonitorTemplateAPIModeMismatch = infraerrors.BadRequest(
+ "CHANNEL_MONITOR_TEMPLATE_API_MODE_MISMATCH", "monitor api_mode does not match template api_mode",
+ )
ErrChannelMonitorTemplateApplyEmpty = infraerrors.BadRequest(
"CHANNEL_MONITOR_TEMPLATE_APPLY_EMPTY", "monitor_ids must be a non-empty array",
)
diff --git a/backend/internal/service/channel_monitor_types.go b/backend/internal/service/channel_monitor_types.go
index b797a89b..ef86eeb8 100644
--- a/backend/internal/service/channel_monitor_types.go
+++ b/backend/internal/service/channel_monitor_types.go
@@ -15,11 +15,23 @@ const (
MonitorBodyOverrideModeReplace = "replace"
)
+// MonitorAPIMode 描述 OpenAI provider 的请求协议。
+//
+// - chat_completions OpenAI-compatible Chat Completions: /v1/chat/completions + messages
+// - responses OpenAI Responses API: /v1/responses + instructions/input
+//
+// 非 OpenAI provider 固定使用 chat_completions 作为占位默认值,避免为每个 provider 单独扩表。
+const (
+ MonitorAPIModeChatCompletions = "chat_completions"
+ MonitorAPIModeResponses = "responses"
+)
+
// ChannelMonitor 渠道监控配置(service 层模型,不直接暴露 ent 类型)。
type ChannelMonitor struct {
ID int64
Name string
Provider string
+ APIMode string
Endpoint string
APIKey string // 解密后的明文 API Key(仅在 service 内部使用,handler 层不应直接序列化返回)
PrimaryModel string
@@ -56,6 +68,7 @@ type ChannelMonitorListParams struct {
type ChannelMonitorCreateParams struct {
Name string
Provider string
+ APIMode string
Endpoint string
APIKey string
PrimaryModel string
@@ -74,6 +87,7 @@ type ChannelMonitorCreateParams struct {
type ChannelMonitorUpdateParams struct {
Name *string
Provider *string
+ APIMode *string
Endpoint *string
APIKey *string // 空字符串表示不修改;非空字符串覆盖
PrimaryModel *string
diff --git a/backend/internal/service/channel_monitor_validate.go b/backend/internal/service/channel_monitor_validate.go
index 16bbec71..6ff22b9f 100644
--- a/backend/internal/service/channel_monitor_validate.go
+++ b/backend/internal/service/channel_monitor_validate.go
@@ -18,6 +18,23 @@ func validateProvider(p string) error {
return nil
}
+// validateAPIMode 校验 provider 与 api_mode 的组合。
+// responses 只对 OpenAI 有意义;其它 provider 使用 chat_completions 作为默认占位。
+func validateAPIMode(provider, apiMode string) error {
+ apiMode = defaultAPIMode(apiMode)
+ switch apiMode {
+ case MonitorAPIModeChatCompletions:
+ return nil
+ case MonitorAPIModeResponses:
+ if provider == "" || provider == MonitorProviderOpenAI {
+ return nil
+ }
+ return ErrChannelMonitorInvalidAPIMode
+ default:
+ return ErrChannelMonitorInvalidAPIMode
+ }
+}
+
// validateInterval 校验 interval_seconds 范围。
func validateInterval(sec int) error {
if sec < monitorMinIntervalSeconds || sec > monitorMaxIntervalSeconds {
@@ -97,3 +114,11 @@ func normalizeModels(in []string) []string {
}
return out
}
+
+// defaultAPIMode 空串归一为 chat_completions,保证历史数据与旧客户端兼容。
+func defaultAPIMode(apiMode string) string {
+ if strings.TrimSpace(apiMode) == "" {
+ return MonitorAPIModeChatCompletions
+ }
+ return strings.TrimSpace(apiMode)
+}
diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go
index 4e08df4a..4bf0147f 100644
--- a/backend/internal/service/channel_service.go
+++ b/backend/internal/service/channel_service.go
@@ -951,7 +951,7 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error {
func validatePricingIntervals(pricingList []ChannelModelPricing) error {
for _, pricing := range pricingList {
- if err := ValidateIntervals(pricing.Intervals); err != nil {
+ if err := ValidateIntervals(pricing.Intervals, pricing.BillingMode); err != nil {
return infraerrors.BadRequest(
"INVALID_PRICING_INTERVALS",
fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v",
diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go
index 26db59a7..6b4bbef8 100644
--- a/backend/internal/service/channel_test.go
+++ b/backend/internal/service/channel_test.go
@@ -311,8 +311,8 @@ func TestChannelClone_EdgeCases(t *testing.T) {
// --- ValidateIntervals ---
func TestValidateIntervals_Empty(t *testing.T) {
- require.NoError(t, ValidateIntervals(nil))
- require.NoError(t, ValidateIntervals([]PricingInterval{}))
+ require.NoError(t, ValidateIntervals(nil, BillingModeToken))
+ require.NoError(t, ValidateIntervals([]PricingInterval{}, BillingModeToken))
}
func TestValidateIntervals_ValidIntervals(t *testing.T) {
@@ -357,7 +357,7 @@ func TestValidateIntervals_ValidIntervals(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- require.NoError(t, ValidateIntervals(tt.intervals))
+ require.NoError(t, ValidateIntervals(tt.intervals, BillingModeToken))
})
}
}
@@ -366,7 +366,7 @@ func TestValidateIntervals_NegativeMinTokens(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "min_tokens")
require.Contains(t, err.Error(), ">= 0")
@@ -376,7 +376,7 @@ func TestValidateIntervals_MaxTokensZero(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> 0")
@@ -386,7 +386,7 @@ func TestValidateIntervals_MaxLessThanMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
@@ -396,7 +396,7 @@ func TestValidateIntervals_MaxEqualsMin(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "max_tokens")
require.Contains(t, err.Error(), "> min_tokens")
@@ -407,7 +407,7 @@ func TestValidateIntervals_NegativePrice(t *testing.T) {
intervals := []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "input_price")
require.Contains(t, err.Error(), ">= 0")
@@ -418,7 +418,7 @@ func TestValidateIntervals_OverlappingIntervals(t *testing.T) {
{MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "overlap")
}
@@ -428,12 +428,43 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) {
{MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)},
{MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)},
}
- err := ValidateIntervals(intervals)
+ err := ValidateIntervals(intervals, BillingModeToken)
require.Error(t, err)
require.Contains(t, err.Error(), "unbounded")
require.Contains(t, err.Error(), "last")
}
+func TestValidateIntervals_ImageModeAllowsMultipleUnboundedTiers(t *testing.T) {
+ // image / per_request 按 tier_label 匹配,多条 min=0/max=nil 是合法形态。
+ intervals := []PricingInterval{
+ {MinTokens: 0, MaxTokens: nil, TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
+ {MinTokens: 0, MaxTokens: nil, TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.06)},
+ {MinTokens: 0, MaxTokens: nil, TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.08)},
+ }
+ require.NoError(t, ValidateIntervals(intervals, BillingModeImage))
+ require.NoError(t, ValidateIntervals(intervals, BillingModePerRequest))
+}
+
+func TestValidateIntervals_ImageModeStillRejectsNegativePrice(t *testing.T) {
+ // image 模式只跳过区间重叠校验,单条字段自洽(价格非负)仍要校验。
+ intervals := []PricingInterval{
+ {MinTokens: 0, MaxTokens: nil, TierLabel: "1K", PerRequestPrice: testPtrFloat64(-1)},
+ }
+ err := ValidateIntervals(intervals, BillingModeImage)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "must be >= 0")
+}
+
+func TestValidateIntervals_ImageModeStillRejectsBadMaxTokens(t *testing.T) {
+ // image 模式仍校验 max <= min 这种单条不合法。
+ intervals := []PricingInterval{
+ {MinTokens: 100, MaxTokens: testPtrInt(50), TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)},
+ }
+ err := ValidateIntervals(intervals, BillingModeImage)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "must be > min_tokens")
+}
+
func TestSupportedModels_ExactKeysAndPricing(t *testing.T) {
ch := &Channel{
ModelPricing: []ChannelModelPricing{
diff --git a/backend/internal/service/content_moderation.go b/backend/internal/service/content_moderation.go
index 144222c2..6a7c9904 100644
--- a/backend/internal/service/content_moderation.go
+++ b/backend/internal/service/content_moderation.go
@@ -32,10 +32,17 @@ const (
contentModerationAPIKeysModeAppend = "append"
contentModerationAPIKeysModeReplace = "replace"
- ContentModerationActionAllow = "allow"
- ContentModerationActionBlock = "block"
- ContentModerationActionHashBlock = "hash_block"
- ContentModerationActionError = "error"
+ ContentModerationActionAllow = "allow"
+ ContentModerationActionBlock = "block"
+ ContentModerationActionHashBlock = "hash_block"
+ ContentModerationActionKeywordBlock = "keyword_block"
+ ContentModerationActionError = "error"
+
+ contentModerationKeywordCategory = "keyword"
+
+ ContentModerationKeywordModeKeywordOnly = "keyword_only"
+ ContentModerationKeywordModeKeywordAndAPI = "keyword_and_api"
+ ContentModerationKeywordModeAPIOnly = "api_only"
ContentModerationProtocolAnthropicMessages = "anthropic_messages"
ContentModerationProtocolOpenAIResponses = "openai_responses"
@@ -71,6 +78,8 @@ const (
maxContentModerationTestImages = maxContentModerationInputImages
maxContentModerationTestImageBytes = 8 * 1024 * 1024
maxContentModerationTestImageDataURLBytes = 12 * 1024 * 1024
+ maxContentModerationBlockedKeywords = 10000
+ maxContentModerationBlockedKeywordRunes = 200
contentModerationCleanupInterval = 24 * time.Hour
contentModerationCleanupTimeout = 30 * time.Minute
@@ -142,6 +151,8 @@ type ContentModerationConfig struct {
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
+ BlockedKeywords []string `json:"blocked_keywords"`
+ KeywordBlockingMode string `json:"keyword_blocking_mode"`
}
type ContentModerationConfigView struct {
@@ -171,6 +182,8 @@ type ContentModerationConfigView struct {
HitRetentionDays int `json:"hit_retention_days"`
NonHitRetentionDays int `json:"non_hit_retention_days"`
PreHashCheckEnabled bool `json:"pre_hash_check_enabled"`
+ BlockedKeywords []string `json:"blocked_keywords"`
+ KeywordBlockingMode string `json:"keyword_blocking_mode"`
}
type ContentModerationAPIKeyStatus struct {
@@ -240,6 +253,8 @@ type UpdateContentModerationConfigInput struct {
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
+ BlockedKeywords *[]string `json:"blocked_keywords"`
+ KeywordBlockingMode *string `json:"keyword_blocking_mode"`
}
type ContentModerationCheckInput struct {
@@ -560,6 +575,12 @@ func (s *ContentModerationService) UpdateConfig(ctx context.Context, input Updat
if input.PreHashCheckEnabled != nil {
cfg.PreHashCheckEnabled = *input.PreHashCheckEnabled
}
+ if input.BlockedKeywords != nil {
+ cfg.BlockedKeywords = normalizeBlockedKeywords(*input.BlockedKeywords)
+ }
+ if input.KeywordBlockingMode != nil {
+ cfg.KeywordBlockingMode = strings.TrimSpace(*input.KeywordBlockingMode)
+ }
if input.AllGroups != nil {
cfg.AllGroups = *input.AllGroups
}
@@ -767,6 +788,44 @@ func (s *ContentModerationService) Check(ctx context.Context, input ContentModer
"protocol", input.Protocol,
"text_runes", len([]rune(content.Text)),
"image_count", len(content.Images))
+ if cfg.Mode == ContentModerationModePreBlock {
+ if cfg.KeywordBlockingMode != ContentModerationKeywordModeAPIOnly && len(cfg.BlockedKeywords) > 0 {
+ if keyword, hit := matchBlockedKeyword(content.Text, cfg.BlockedKeywords); hit {
+ slog.Info("content_moderation.keyword_block",
+ "user_id", input.UserID,
+ "api_key_id", input.APIKeyID,
+ "group_id", contentModerationLogGroupID(input.GroupID),
+ "endpoint", input.Endpoint,
+ "protocol", input.Protocol,
+ "keyword_blocking_mode", cfg.KeywordBlockingMode,
+ "keyword", keyword)
+ scores := map[string]float64{contentModerationKeywordCategory: 1.0}
+ log := s.buildLog(input, cfg, ContentModerationActionKeywordBlock, true, contentModerationKeywordCategory, 1.0, scores, content.ExcerptText(), nil, nil, "")
+ s.applyFlaggedSideEffects(ctx, cfg, log)
+ _ = s.repo.CreateLog(ctx, log)
+ return &ContentModerationDecision{
+ Allowed: false,
+ Blocked: true,
+ Flagged: true,
+ Message: cfg.BlockMessage,
+ StatusCode: cfg.BlockStatus,
+ HighestCategory: contentModerationKeywordCategory,
+ HighestScore: 1.0,
+ CategoryScores: scores,
+ Action: ContentModerationActionKeywordBlock,
+ }, nil
+ }
+ }
+ if cfg.KeywordBlockingMode == ContentModerationKeywordModeKeywordOnly {
+ slog.Info("content_moderation.skip_api_keyword_only",
+ "user_id", input.UserID,
+ "api_key_id", input.APIKeyID,
+ "group_id", contentModerationLogGroupID(input.GroupID),
+ "endpoint", input.Endpoint,
+ "protocol", input.Protocol)
+ return allow, nil
+ }
+ }
hashText := content.Hash()
if cfg.PreHashCheckEnabled && s.hashCache != nil {
matched, err := s.hashCache.HasFlaggedInputHash(ctx, hashText)
@@ -1451,6 +1510,8 @@ func defaultContentModerationConfig() *ContentModerationConfig {
HitRetentionDays: defaultContentModerationHitRetentionDays,
NonHitRetentionDays: defaultContentModerationNonHitRetentionDays,
PreHashCheckEnabled: false,
+ BlockedKeywords: []string{},
+ KeywordBlockingMode: ContentModerationKeywordModeKeywordAndAPI,
}
}
@@ -1529,6 +1590,8 @@ func (cfg *ContentModerationConfig) normalize() {
}
cfg.GroupIDs = normalizeInt64IDs(cfg.GroupIDs)
cfg.Thresholds = mergeContentModerationThresholds(ContentModerationDefaultThresholds(), cfg.Thresholds)
+ cfg.BlockedKeywords = normalizeBlockedKeywords(cfg.BlockedKeywords)
+ cfg.KeywordBlockingMode = normalizeKeywordBlockingMode(cfg.KeywordBlockingMode)
}
func (cfg *ContentModerationConfig) includesGroup(groupID *int64) bool {
@@ -1705,6 +1768,8 @@ func (s *ContentModerationService) configView(cfg *ContentModerationConfig) *Con
HitRetentionDays: cfg.HitRetentionDays,
NonHitRetentionDays: cfg.NonHitRetentionDays,
PreHashCheckEnabled: cfg.PreHashCheckEnabled,
+ BlockedKeywords: append([]string(nil), cfg.BlockedKeywords...),
+ KeywordBlockingMode: cfg.KeywordBlockingMode,
}
}
@@ -1944,6 +2009,60 @@ func normalizeInt64IDs(ids []int64) []int64 {
return out
}
+func normalizeBlockedKeywords(in []string) []string {
+ if len(in) == 0 {
+ return []string{}
+ }
+ out := make([]string, 0, len(in))
+ seen := make(map[string]struct{}, len(in))
+ for _, raw := range in {
+ kw := strings.TrimSpace(raw)
+ if kw == "" {
+ continue
+ }
+ kw = trimRunes(kw, maxContentModerationBlockedKeywordRunes)
+ key := strings.ToLower(kw)
+ if _, ok := seen[key]; ok {
+ continue
+ }
+ seen[key] = struct{}{}
+ out = append(out, kw)
+ if len(out) >= maxContentModerationBlockedKeywords {
+ break
+ }
+ }
+ return out
+}
+
+func normalizeKeywordBlockingMode(mode string) string {
+ switch strings.TrimSpace(mode) {
+ case ContentModerationKeywordModeKeywordOnly:
+ return ContentModerationKeywordModeKeywordOnly
+ case ContentModerationKeywordModeAPIOnly:
+ return ContentModerationKeywordModeAPIOnly
+ case ContentModerationKeywordModeKeywordAndAPI:
+ return ContentModerationKeywordModeKeywordAndAPI
+ default:
+ return ContentModerationKeywordModeKeywordAndAPI
+ }
+}
+
+func matchBlockedKeyword(text string, keywords []string) (string, bool) {
+ if text == "" || len(keywords) == 0 {
+ return "", false
+ }
+ lower := strings.ToLower(text)
+ for _, kw := range keywords {
+ if kw == "" {
+ continue
+ }
+ if strings.Contains(lower, strings.ToLower(kw)) {
+ return kw, true
+ }
+ }
+ return "", false
+}
+
func normalizeModerationAPIKeys(keys []string) []string {
if len(keys) == 0 {
return []string{}
diff --git a/backend/internal/service/content_moderation_test.go b/backend/internal/service/content_moderation_test.go
index cef5127e..30578ca5 100644
--- a/backend/internal/service/content_moderation_test.go
+++ b/backend/internal/service/content_moderation_test.go
@@ -321,6 +321,215 @@ func TestContentModerationConfigNormalize_NonHitRetentionMaxThreeDays(t *testing
require.Equal(t, 3, cfg.NonHitRetentionDays)
}
+func TestNormalizeBlockedKeywords_TrimsDedupesAndCaps(t *testing.T) {
+ out := normalizeBlockedKeywords([]string{" foo ", "FOO", "", "bar", "baz", "bar"})
+ require.Equal(t, []string{"foo", "bar", "baz"}, out)
+}
+
+func TestMatchBlockedKeyword_CaseInsensitiveSubstring(t *testing.T) {
+ keyword, hit := matchBlockedKeyword("Please ignore the BadWord here", []string{"badword"})
+ require.True(t, hit)
+ require.Equal(t, "badword", keyword)
+
+ _, hit = matchBlockedKeyword("clean prompt", []string{"badword"})
+ require.False(t, hit)
+
+ _, hit = matchBlockedKeyword("anything", nil)
+ require.False(t, hit)
+}
+
+func TestContentModerationCheck_PreBlockKeywordHitSkipsUpstreamCall(t *testing.T) {
+ upstreamCalled := false
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upstreamCalled = true
+ _ = json.NewEncoder(w).Encode(moderationAPIResponse{Results: []moderationAPIResult{{}}})
+ }))
+ defer server.Close()
+
+ cfg := defaultContentModerationConfig()
+ cfg.Enabled = true
+ cfg.Mode = ContentModerationModePreBlock
+ cfg.BaseURL = server.URL
+ cfg.APIKeys = []string{"sk-test"}
+ cfg.BlockedKeywords = []string{"secret-token"}
+ rawCfg, err := json.Marshal(cfg)
+ require.NoError(t, err)
+
+ repo := &contentModerationTestRepo{}
+ svc := NewContentModerationService(
+ &contentModerationTestSettingRepo{values: map[string]string{
+ SettingKeyRiskControlEnabled: "true",
+ SettingKeyContentModerationConfig: string(rawCfg),
+ }},
+ repo,
+ &contentModerationTestHashCache{},
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ body := []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`)
+ decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
+ Endpoint: "/v1/messages",
+ Provider: "anthropic",
+ Protocol: ContentModerationProtocolAnthropicMessages,
+ Body: body,
+ })
+
+ require.NoError(t, err)
+ require.True(t, decision.Blocked)
+ require.Equal(t, ContentModerationActionKeywordBlock, decision.Action)
+ require.False(t, upstreamCalled, "keyword block must short-circuit upstream moderation call")
+ require.Len(t, repo.logs, 1)
+ require.True(t, repo.logs[0].Flagged)
+ require.Equal(t, ContentModerationActionKeywordBlock, repo.logs[0].Action)
+ require.Equal(t, contentModerationKeywordCategory, repo.logs[0].HighestCategory)
+}
+
+func TestContentModerationCheck_KeywordsIgnoredInObserveMode(t *testing.T) {
+ upstreamHits := 0
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upstreamHits++
+ _ = json.NewEncoder(w).Encode(moderationAPIResponse{Results: []moderationAPIResult{{CategoryScores: map[string]float64{"sexual": 0.1}}}})
+ }))
+ defer server.Close()
+
+ cfg := defaultContentModerationConfig()
+ cfg.Enabled = true
+ cfg.Mode = ContentModerationModeObserve
+ cfg.BaseURL = server.URL
+ cfg.APIKeys = []string{"sk-test"}
+ cfg.BlockedKeywords = []string{"secret-token"}
+ rawCfg, err := json.Marshal(cfg)
+ require.NoError(t, err)
+
+ repo := &contentModerationTestRepo{}
+ svc := NewContentModerationService(
+ &contentModerationTestSettingRepo{values: map[string]string{
+ SettingKeyRiskControlEnabled: "true",
+ SettingKeyContentModerationConfig: string(rawCfg),
+ }},
+ repo,
+ &contentModerationTestHashCache{},
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ body := []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`)
+ decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
+ Endpoint: "/v1/messages",
+ Provider: "anthropic",
+ Protocol: ContentModerationProtocolAnthropicMessages,
+ Body: body,
+ })
+
+ require.NoError(t, err)
+ require.True(t, decision.Allowed, "observe mode must let the request through even on keyword hit")
+ require.Equal(t, ContentModerationActionAllow, decision.Action)
+}
+
+func TestContentModerationCheck_KeywordOnlyStrategySkipsAPIOnMiss(t *testing.T) {
+ upstreamCalled := false
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upstreamCalled = true
+ _ = json.NewEncoder(w).Encode(moderationAPIResponse{Results: []moderationAPIResult{{CategoryScores: map[string]float64{"sexual": 0.99}}}})
+ }))
+ defer server.Close()
+
+ cfg := defaultContentModerationConfig()
+ cfg.Enabled = true
+ cfg.Mode = ContentModerationModePreBlock
+ cfg.BaseURL = server.URL
+ cfg.APIKeys = []string{"sk-test"}
+ cfg.BlockedKeywords = []string{"never-matches"}
+ cfg.KeywordBlockingMode = ContentModerationKeywordModeKeywordOnly
+ rawCfg, err := json.Marshal(cfg)
+ require.NoError(t, err)
+
+ repo := &contentModerationTestRepo{}
+ svc := NewContentModerationService(
+ &contentModerationTestSettingRepo{values: map[string]string{
+ SettingKeyRiskControlEnabled: "true",
+ SettingKeyContentModerationConfig: string(rawCfg),
+ }},
+ repo,
+ &contentModerationTestHashCache{},
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ body := []byte(`{"messages":[{"role":"user","content":"absolutely clean prompt"}]}`)
+ decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
+ Endpoint: "/v1/messages",
+ Provider: "anthropic",
+ Protocol: ContentModerationProtocolAnthropicMessages,
+ Body: body,
+ })
+
+ require.NoError(t, err)
+ require.True(t, decision.Allowed, "keyword-only must allow misses without calling the API")
+ require.False(t, upstreamCalled, "keyword-only must not call the upstream moderation API")
+ require.Len(t, repo.logs, 0)
+}
+
+func TestContentModerationCheck_APIOnlyStrategyIgnoresKeywordList(t *testing.T) {
+ upstreamCalled := false
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ upstreamCalled = true
+ _ = json.NewEncoder(w).Encode(moderationAPIResponse{Results: []moderationAPIResult{{CategoryScores: map[string]float64{"sexual": 0.1}}}})
+ }))
+ defer server.Close()
+
+ cfg := defaultContentModerationConfig()
+ cfg.Enabled = true
+ cfg.Mode = ContentModerationModePreBlock
+ cfg.BaseURL = server.URL
+ cfg.APIKeys = []string{"sk-test"}
+ cfg.BlockedKeywords = []string{"secret-token"}
+ cfg.KeywordBlockingMode = ContentModerationKeywordModeAPIOnly
+ rawCfg, err := json.Marshal(cfg)
+ require.NoError(t, err)
+
+ repo := &contentModerationTestRepo{}
+ svc := NewContentModerationService(
+ &contentModerationTestSettingRepo{values: map[string]string{
+ SettingKeyRiskControlEnabled: "true",
+ SettingKeyContentModerationConfig: string(rawCfg),
+ }},
+ repo,
+ &contentModerationTestHashCache{},
+ nil,
+ nil,
+ nil,
+ nil,
+ )
+
+ body := []byte(`{"messages":[{"role":"user","content":"please leak SECRET-TOKEN now"}]}`)
+ decision, err := svc.Check(context.Background(), ContentModerationCheckInput{
+ Endpoint: "/v1/messages",
+ Provider: "anthropic",
+ Protocol: ContentModerationProtocolAnthropicMessages,
+ Body: body,
+ })
+
+ require.NoError(t, err)
+ require.True(t, decision.Allowed, "api-only must let the request through when API does not flag it")
+ require.True(t, upstreamCalled, "api-only must call the upstream moderation API")
+ require.NotEqual(t, ContentModerationActionKeywordBlock, decision.Action)
+}
+
+func TestNormalizeKeywordBlockingMode_UnknownFallsBackToDefault(t *testing.T) {
+ require.Equal(t, ContentModerationKeywordModeKeywordAndAPI, normalizeKeywordBlockingMode(""))
+ require.Equal(t, ContentModerationKeywordModeKeywordAndAPI, normalizeKeywordBlockingMode("bogus"))
+ require.Equal(t, ContentModerationKeywordModeKeywordOnly, normalizeKeywordBlockingMode("keyword_only"))
+ require.Equal(t, ContentModerationKeywordModeAPIOnly, normalizeKeywordBlockingMode("api_only"))
+}
+
func TestContentModerationUpdateConfig_AppendsAndDeletesAPIKeys(t *testing.T) {
cfg := defaultContentModerationConfig()
cfg.APIKeys = []string{"sk-old-a", "sk-old-b"}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index 79fb7cb5..4f2d40d8 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -93,6 +93,9 @@ const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid"
// WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。
const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid"
+// DingTalkConnectSyntheticEmailDomain 是 DingTalk Connect 用户的合成邮箱后缀(RFC 保留域名)。
+const DingTalkConnectSyntheticEmailDomain = "@dingtalk-connect.invalid"
+
// Setting keys
const (
// 注册设置
@@ -138,6 +141,24 @@ const (
SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret"
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
+ // DingTalk Connect OAuth 登录设置
+ SettingKeyDingTalkConnectEnabled = "dingtalk_connect_enabled"
+ SettingKeyDingTalkConnectClientID = "dingtalk_connect_client_id"
+ SettingKeyDingTalkConnectClientSecret = "dingtalk_connect_client_secret"
+ SettingKeyDingTalkConnectRedirectURL = "dingtalk_connect_redirect_url"
+ SettingKeyDingTalkConnectCorpRestrictionPolicy = "dingtalk_connect_corp_restriction_policy"
+ SettingKeyDingTalkConnectInternalCorpID = "dingtalk_connect_internal_corp_id"
+ SettingKeyDingTalkConnectBypassRegistration = "dingtalk_connect_bypass_registration"
+ SettingKeyDingTalkConnectSyncCorpEmail = "dingtalk_connect_sync_corp_email"
+ SettingKeyDingTalkConnectSyncDisplayName = "dingtalk_connect_sync_display_name"
+ SettingKeyDingTalkConnectSyncDept = "dingtalk_connect_sync_dept"
+ SettingKeyDingTalkConnectSyncCorpEmailAttrKey = "dingtalk_connect_sync_corp_email_attr_key"
+ SettingKeyDingTalkConnectSyncDisplayNameAttrKey = "dingtalk_connect_sync_display_name_attr_key"
+ SettingKeyDingTalkConnectSyncDeptAttrKey = "dingtalk_connect_sync_dept_attr_key"
+ SettingKeyDingTalkConnectSyncCorpEmailAttrName = "dingtalk_connect_sync_corp_email_attr_name"
+ SettingKeyDingTalkConnectSyncDisplayNameAttrName = "dingtalk_connect_sync_display_name_attr_name"
+ SettingKeyDingTalkConnectSyncDeptAttrName = "dingtalk_connect_sync_dept_attr_name"
+
// WeChat Connect OAuth 登录设置
SettingKeyWeChatConnectEnabled = "wechat_connect_enabled"
SettingKeyWeChatConnectAppID = "wechat_connect_app_id"
@@ -215,37 +236,42 @@ const (
SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制)
// 第三方认证来源默认授予配置
- SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
- SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
- SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
- SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
- SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
- SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
- SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
- SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
- SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
- SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
- SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
- SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
- SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
- SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
- SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
- SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
- SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
- SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
- SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
- SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
- SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
- SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
- SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
- SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
- SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
- SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
- SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
- SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
- SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
- SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
- SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
+ SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance"
+ SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency"
+ SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions"
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup"
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance"
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency"
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup"
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance"
+ SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency"
+ SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions"
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup"
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance"
+ SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency"
+ SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions"
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup"
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance"
+ SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency"
+ SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions"
+ SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup"
+ SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance"
+ SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency"
+ SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions"
+ SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup"
+ SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind"
+ SettingKeyAuthSourceDefaultDingTalkBalance = "auth_source_default_dingtalk_balance"
+ SettingKeyAuthSourceDefaultDingTalkConcurrency = "auth_source_default_dingtalk_concurrency"
+ SettingKeyAuthSourceDefaultDingTalkSubscriptions = "auth_source_default_dingtalk_subscriptions"
+ SettingKeyAuthSourceDefaultDingTalkGrantOnSignup = "auth_source_default_dingtalk_grant_on_signup"
+ SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind = "auth_source_default_dingtalk_grant_on_first_bind"
+ SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup"
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
index 930f9522..1071e485 100644
--- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
+++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go
@@ -188,11 +188,6 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAnd
require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写")
require.Equal(t, 7, result.Usage.CacheReadInputTokens, "计费 usage 解析应保留 cached_tokens 兼容")
require.Empty(t, rec.Header().Get("Set-Cookie"), "响应头应经过安全过滤")
- rawBody, ok := c.Get(OpsUpstreamRequestBodyKey)
- require.True(t, ok)
- bodyBytes, ok := rawBody.([]byte)
- require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝")
- require.Equal(t, "claude-3-haiku-20240307", gjson.GetBytes(bodyBytes, "model").String(), "缓存的上游请求体应包含映射后的模型")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) {
@@ -973,10 +968,6 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequest
require.Error(t, err)
require.Contains(t, err.Error(), "upstream request failed")
require.Equal(t, http.StatusBadGateway, rec.Code)
- rawBody, ok := c.Get(OpsUpstreamRequestBodyKey)
- require.True(t, ok)
- _, ok = rawBody.([]byte)
- require.True(t, ok)
}
func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody(t *testing.T) {
@@ -1173,6 +1164,99 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(
require.False(t, result.clientDisconnect)
}
+func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingSendsKeepaliveDuringIdle(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ svc := &GatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamKeepaliveInterval: 1,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ },
+ rateLimitService: &RateLimitService{},
+ }
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: pr,
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ time.Sleep(1200 * time.Millisecond)
+ _, _ = pw.Write([]byte(strings.Join([]string{
+ `data: {"type":"message_start","message":{"usage":{"input_tokens":3}}}`,
+ "",
+ `data: {"type":"message_delta","usage":{"output_tokens":2}}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")))
+ _ = pw.Close()
+ }()
+
+ result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 8}, time.Now(), "claude-3-7-sonnet-20250219")
+ _ = pr.Close()
+ <-done
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), "event: ping\ndata: {\"type\": \"ping\"}\n\n")
+ require.Contains(t, rec.Body.String(), "data: [DONE]")
+}
+
+func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingKeepaliveDoesNotInterleavePartialEvent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
+
+ svc := &GatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamKeepaliveInterval: 1,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ },
+ rateLimitService: &RateLimitService{},
+ }
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: pr,
+ }
+
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}` + "\n"))
+ time.Sleep(1200 * time.Millisecond)
+ _, _ = pw.Write([]byte("\n"))
+ _, _ = pw.Write([]byte("data: [DONE]\n\n"))
+ _ = pw.Close()
+ }()
+
+ result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 9}, time.Now(), "claude-3-7-sonnet-20250219")
+ _ = pr.Close()
+ <-done
+
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ body := rec.Body.String()
+ require.NotContains(t, body, `data: {"type":"message_start","message":{"usage":{"input_tokens":4}}}`+"\n"+"event: ping")
+ require.NotContains(t, body, "event: ping")
+ require.Contains(t, body, "data: [DONE]")
+}
+
func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go
index 5e5caeb7..1c3ace93 100644
--- a/backend/internal/service/gateway_record_usage_test.go
+++ b/backend/internal/service/gateway_record_usage_test.go
@@ -193,6 +193,46 @@ func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testin
require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel)
}
+func TestGatewayServiceRecordUsage_EmptyImageSizeDefaultsBeforeBillingAndPersistence(t *testing.T) {
+ imagePrice2K := 0.19
+ groupID := int64(901)
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{})
+
+ err := svc.RecordUsage(context.Background(), &RecordUsageInput{
+ Result: &ForwardResult{
+ RequestID: "gateway_image_default_size",
+ Model: "gemini-image",
+ ImageCount: 1,
+ ImageInputSize: "auto",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 801,
+ GroupID: i64p(groupID),
+ Group: &Group{
+ ID: groupID,
+ RateMultiplier: 1.0,
+ ImagePrice2K: &imagePrice2K,
+ },
+ },
+ User: &User{ID: 601},
+ Account: &Account{ID: 701},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 1, usageRepo.lastLog.ImageCount)
+ require.NotNil(t, usageRepo.lastLog.ImageSize)
+ require.Equal(t, ImageBillingSize2K, *usageRepo.lastLog.ImageSize)
+ require.NotNil(t, usageRepo.lastLog.ImageInputSize)
+ require.Equal(t, "auto", *usageRepo.lastLog.ImageInputSize)
+ require.NotNil(t, usageRepo.lastLog.ImageSizeSource)
+ require.Equal(t, ImageSizeSourceDefault, *usageRepo.lastLog.ImageSizeSource)
+ require.InDelta(t, 0.19, usageRepo.lastLog.TotalCost, 1e-12)
+ require.InDelta(t, 0.19, usageRepo.lastLog.ActualCost, 1e-12)
+}
+
func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) {
usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)}
userRepo := &openAIRecordUsageUserRepoStub{}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index a08fedbf..e25e6d82 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -522,8 +522,13 @@ type ForwardResult struct {
ReasoningEffort *string
// 图片生成计费字段(图片生成模型使用)
- ImageCount int // 生成的图片数量
- ImageSize string // 图片尺寸 "1K", "2K", "4K"
+ ImageCount int // 生成的图片数量
+ ImageSize string // 最终计费尺寸 "1K", "2K", "4K"
+ ImageInputSize string // 请求中的原始图片尺寸
+ ImageOutputSize string // 上游响应中的图片尺寸
+ ImageOutputSizes []string
+ ImageSizeSource string
+ ImageSizeBreakdown map[string]int
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
@@ -1424,7 +1429,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
-// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。
// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID
// sub2apiUserID: 系统用户 ID,用于二维亲和调度
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) {
@@ -4611,9 +4615,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body = compressMessagesInBody(body, maxTok)
}
- // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
- setOpsUpstreamRequestBody(c, body)
-
// 重试循环
var resp *http.Response
retryStart := time.Now()
@@ -5116,9 +5117,6 @@ func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput(
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
input.Body = StripEmptyTextBlocks(input.Body)
- // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
- setOpsUpstreamRequestBody(c, input.Body)
-
var resp *http.Response
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
@@ -5459,6 +5457,22 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
intervalCh = intervalTicker.C
}
+ keepaliveInterval := time.Duration(0)
+ if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
+ keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
+ }
+ var keepaliveTicker *time.Ticker
+ if keepaliveInterval > 0 {
+ keepaliveTicker = time.NewTicker(keepaliveInterval)
+ defer keepaliveTicker.Stop()
+ }
+ var keepaliveCh <-chan time.Time
+ if keepaliveTicker != nil {
+ keepaliveCh = keepaliveTicker.C
+ }
+ lastDataAt := time.Now()
+ inPartialEvent := false
+
for {
select {
case ev, ok := <-events:
@@ -5524,6 +5538,10 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
} else if line == "" {
// 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。
flusher.Flush()
+ lastDataAt = time.Now()
+ inPartialEvent = false
+ } else {
+ inPartialEvent = true
}
}
@@ -5540,6 +5558,21 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough(
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
+
+ case <-keepaliveCh:
+ if clientDisconnected || inPartialEvent {
+ continue
+ }
+ if time.Since(lastDataAt) < keepaliveInterval {
+ continue
+ }
+ if _, err := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); err != nil {
+ clientDisconnected = true
+ logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during keepalive ping, continue draining upstream for usage: account=%d", account.ID)
+ continue
+ }
+ flusher.Flush()
+ lastDataAt = time.Now()
}
}
}
@@ -6288,7 +6321,6 @@ func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
if err != nil {
return nil, err
}
- setOpsUpstreamRequestBody(c, vertexBody)
fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
if err != nil {
return nil, err
@@ -8492,6 +8524,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
user := input.User
account := input.Account
subscription := input.Subscription
+ ApplyForwardImageBillingResolution(result)
// 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens
// 用于粘性会话切换时的特殊计费处理
@@ -8637,6 +8670,7 @@ func (s *GatewayService) calculateImageCost(
billingModel string,
multiplier float64,
) *CostBreakdown {
+ sizeTier := NormalizeImageBillingTierOrDefault(result.ImageSize)
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil {
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
@@ -8650,7 +8684,7 @@ func (s *GatewayService) calculateImageCost(
GroupID: &gid,
Tokens: tokens,
RequestCount: result.ImageCount,
- SizeTier: result.ImageSize,
+ SizeTier: sizeTier,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
@@ -8670,7 +8704,7 @@ func (s *GatewayService) calculateImageCost(
Price4K: apiKey.Group.ImagePrice4K,
}
}
- return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
+ return s.billingService.CalculateImageCost(billingModel, sizeTier, result.ImageCount, groupConfig, multiplier)
}
// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。
@@ -8771,6 +8805,10 @@ func (s *GatewayService) buildRecordUsageLog(
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
+ ImageInputSize: optionalTrimmedStringPtr(result.ImageInputSize),
+ ImageOutputSize: optionalTrimmedStringPtr(result.ImageOutputSize),
+ ImageSizeSource: optionalTrimmedStringPtr(result.ImageSizeSource),
+ ImageSizeBreakdown: result.ImageSizeBreakdown,
CacheTTLOverridden: cacheTTLOverridden,
ChannelID: optionalInt64Ptr(input.ChannelID),
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
diff --git a/backend/internal/service/gemini_chat_completions_compat_service.go b/backend/internal/service/gemini_chat_completions_compat_service.go
new file mode 100644
index 00000000..dcc3213b
--- /dev/null
+++ b/backend/internal/service/gemini_chat_completions_compat_service.go
@@ -0,0 +1,886 @@
+package service
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
+ "github.com/gin-gonic/gin"
+)
+
+// ForwardAsChatCompletions serves OpenAI Chat Completions clients through
+// Gemini accounts. It keeps the client-facing response in Chat Completions
+// format while routing the upstream call through Gemini native endpoints.
+func (s *GeminiMessagesCompatService) ForwardAsChatCompletions(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ body []byte,
+) (*ForwardResult, error) {
+ startTime := time.Now()
+
+ var ccReq apicompat.ChatCompletionsRequest
+ if err := json.Unmarshal(body, &ccReq); err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+ }
+ if strings.TrimSpace(ccReq.Model) == "" {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
+ }
+
+ originalModel := ccReq.Model
+ clientStream := ccReq.Stream
+ includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage
+
+ responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ }
+
+ anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ }
+ anthropicReq.Stream = clientStream
+
+ claudeBody, err := json.Marshal(anthropicReq)
+ if err != nil {
+ return nil, fmt.Errorf("marshal chat completions compat request: %w", err)
+ }
+
+ return s.forwardClaudeBodyAsChatCompletions(ctx, c, account, claudeBody, originalModel, clientStream, includeUsage, startTime, body)
+}
+
+func (s *GeminiMessagesCompatService) forwardClaudeBodyAsChatCompletions(
+ ctx context.Context,
+ c *gin.Context,
+ account *Account,
+ claudeBody []byte,
+ originalModel string,
+ clientStream bool,
+ includeUsage bool,
+ startTime time.Time,
+ originalChatBody []byte,
+) (*ForwardResult, error) {
+ var req struct {
+ Model string `json:"model"`
+ Stream bool `json:"stream"`
+ }
+ if err := json.Unmarshal(claudeBody, &req); err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
+ }
+ if strings.TrimSpace(req.Model) == "" {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required")
+ }
+
+ mappedModel := req.Model
+ if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount {
+ mappedModel = account.GetMappedModel(req.Model)
+ }
+
+ geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(claudeBody)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error())
+ }
+ geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq)
+
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ useUpstreamStream := clientStream
+ if account.Type == AccountTypeOAuth && !clientStream && strings.TrimSpace(account.GetCredential("project_id")) != "" {
+ useUpstreamStream = true
+ }
+
+ buildReq, requestIDHeader := s.buildGeminiChatCompletionsUpstreamRequestFunc(
+ account,
+ mappedModel,
+ geminiReq,
+ clientStream,
+ useUpstreamStream,
+ )
+
+ var resp *http.Response
+ for attempt := 1; attempt <= geminiMaxRetries; attempt++ {
+ upstreamReq, idHeader, err := buildReq(ctx)
+ if err != nil {
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil, err
+ }
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", err.Error())
+ }
+ requestIDHeader = idHeader
+
+ resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
+ if err != nil {
+ safeErr := sanitizeUpstreamErrorMessage(err.Error())
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: 0,
+ Kind: "request_error",
+ Message: safeErr,
+ })
+ if attempt < geminiMaxRetries {
+ logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err)
+ sleepGeminiBackoff(attempt)
+ continue
+ }
+ setOpsUpstreamError(c, 0, safeErr, "")
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr)
+ }
+
+ if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched {
+ resp = rebuilt
+ break
+ } else {
+ resp = rebuilt
+ }
+
+ if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ _ = resp.Body.Close()
+ if resp.StatusCode == http.StatusForbidden && isGeminiInsufficientScope(resp.Header, respBody) {
+ resp = &http.Response{
+ StatusCode: resp.StatusCode,
+ Header: resp.Header.Clone(),
+ Body: io.NopCloser(bytes.NewReader(respBody)),
+ }
+ break
+ }
+ if resp.StatusCode == http.StatusTooManyRequests {
+ s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
+ }
+ if attempt < geminiMaxRetries {
+ upstreamReqID := resp.Header.Get(requestIDHeader)
+ if upstreamReqID == "" {
+ upstreamReqID = resp.Header.Get("x-goog-request-id")
+ }
+ upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
+ upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: upstreamReqID,
+ Kind: "retry",
+ Message: upstreamMsg,
+ })
+ logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries)
+ sleepGeminiBackoff(attempt)
+ continue
+ }
+ resp = &http.Response{
+ StatusCode: resp.StatusCode,
+ Header: resp.Header.Clone(),
+ Body: io.NopCloser(bytes.NewReader(respBody)),
+ }
+ break
+ }
+
+ break
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ requestID := resp.Header.Get(requestIDHeader)
+ if requestID == "" {
+ requestID = resp.Header.Get("x-goog-request-id")
+ }
+ if requestID != "" {
+ c.Header("x-request-id", requestID)
+ }
+
+ reasoningEffort := extractCCReasoningEffortFromBody(originalChatBody)
+
+ if resp.StatusCode >= 400 {
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+ s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
+ evBody := unwrapIfNeeded(account.Type == AccountTypeOAuth, respBody)
+
+ if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
+ upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody)))
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: resp.StatusCode,
+ UpstreamRequestID: requestID,
+ Kind: "failover",
+ Message: upstreamMsg,
+ })
+ return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
+ }
+
+ return nil, s.writeGeminiChatCompletionsMappedError(c, account, resp.StatusCode, requestID, evBody)
+ }
+
+ var usage *ClaudeUsage
+ var firstTokenMs *int
+ if clientStream {
+ streamRes, err := s.handleChatCompletionsStreamingResponseFromGemini(c, resp, startTime, originalModel, account.Type == AccountTypeOAuth, includeUsage)
+ if err != nil {
+ return nil, err
+ }
+ usage = streamRes.usage
+ firstTokenMs = streamRes.firstTokenMs
+ } else if useUpstreamStream {
+ collected, usageObj, err := collectGeminiSSE(resp.Body, account.Type == AccountTypeOAuth)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
+ }
+ collectedBytes, _ := json.Marshal(collected)
+ chatResp, usageObj2, err := geminiResponseToChatCompletions(collected, originalModel, collectedBytes, usageObj)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
+ }
+ c.JSON(http.StatusOK, chatResp)
+ usage = usageObj2
+ } else {
+ usageResp, err := s.handleChatCompletionsNonStreamingResponseFromGemini(c, resp, originalModel, account.Type == AccountTypeOAuth)
+ if err != nil {
+ return nil, err
+ }
+ usage = usageResp
+ }
+
+ if usage == nil {
+ usage = &ClaudeUsage{}
+ }
+
+ imageCount := 0
+ imageInputSize := s.extractImageInputSize(claudeBody)
+ imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
+ if isImageGenerationModel(originalModel) {
+ imageCount = 1
+ }
+
+ return &ForwardResult{
+ RequestID: requestID,
+ Usage: *usage,
+ Model: originalModel,
+ UpstreamModel: mappedModel,
+ Stream: clientStream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ReasoningEffort: reasoningEffort,
+ ImageCount: imageCount,
+ ImageSize: imageSize,
+ ImageInputSize: imageInputSize,
+ ClientDisconnect: false,
+ }, nil
+}
+
+func (s *GeminiMessagesCompatService) buildGeminiChatCompletionsUpstreamRequestFunc(
+ account *Account,
+ mappedModel string,
+ geminiReq []byte,
+ clientStream bool,
+ useUpstreamStream bool,
+) (func(context.Context) (*http.Request, string, error), string) {
+ switch account.Type {
+ case AccountTypeAPIKey:
+ return func(ctx context.Context) (*http.Request, string, error) {
+ apiKey := account.GetCredential("api_key")
+ if strings.TrimSpace(apiKey) == "" {
+ return nil, "", errors.New("gemini api_key not configured")
+ }
+
+ baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
+
+ action := "generateContent"
+ if clientStream {
+ action = "streamGenerateContent"
+ }
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
+ if clientStream {
+ fullURL += "?alt=sse"
+ }
+
+ restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("x-goog-api-key", apiKey)
+ return upstreamReq, "x-request-id", nil
+ }, "x-request-id"
+
+ case AccountTypeOAuth:
+ return func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ projectID := strings.TrimSpace(account.GetCredential("project_id"))
+ action := "generateContent"
+ if useUpstreamStream {
+ action = "streamGenerateContent"
+ }
+
+ if projectID != "" {
+ baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL)
+ if err != nil {
+ return nil, "", err
+ }
+ fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action)
+ if useUpstreamStream {
+ fullURL += "?alt=sse"
+ }
+
+ var inner any
+ if err := json.Unmarshal(geminiReq, &inner); err != nil {
+ return nil, "", fmt.Errorf("failed to parse gemini request: %w", err)
+ }
+ wrappedBytes, _ := json.Marshal(map[string]any{
+ "model": mappedModel,
+ "project": projectID,
+ "request": inner,
+ })
+
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent)
+ return upstreamReq, "x-request-id", nil
+ }
+
+ baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, "", err
+ }
+
+ fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action)
+ if useUpstreamStream {
+ fullURL += "?alt=sse"
+ }
+
+ restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }, "x-request-id"
+
+ case AccountTypeServiceAccount:
+ return func(ctx context.Context) (*http.Request, string, error) {
+ if s.tokenProvider == nil {
+ return nil, "", errors.New("gemini token provider not configured")
+ }
+ accessToken, err := s.tokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, "", err
+ }
+
+ action := "generateContent"
+ if clientStream {
+ action = "streamGenerateContent"
+ }
+ fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, clientStream)
+ if err != nil {
+ return nil, "", err
+ }
+
+ restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
+ upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
+ if err != nil {
+ return nil, "", err
+ }
+ upstreamReq.Header.Set("Content-Type", "application/json")
+ upstreamReq.Header.Set("Authorization", "Bearer "+accessToken)
+ return upstreamReq, "x-request-id", nil
+ }, "x-request-id"
+
+ default:
+ return func(context.Context) (*http.Request, string, error) {
+ return nil, "", fmt.Errorf("unsupported account type: %s", account.Type)
+ }, "x-request-id"
+ }
+}
+
+func (s *GeminiMessagesCompatService) handleChatCompletionsNonStreamingResponseFromGemini(
+ c *gin.Context,
+ resp *http.Response,
+ originalModel string,
+ isOAuth bool,
+) (*ClaudeUsage, error) {
+ respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
+ if err != nil {
+ return nil, err
+ }
+ if isOAuth {
+ if unwrappedBody, uwErr := unwrapGeminiResponse(respBody); uwErr == nil {
+ respBody = unwrappedBody
+ }
+ }
+
+ var geminiResp map[string]any
+ if err := json.Unmarshal(respBody, &geminiResp); err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
+ }
+
+ chatResp, usage, err := geminiResponseToChatCompletions(geminiResp, originalModel, respBody, nil)
+ if err != nil {
+ return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
+ }
+
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ c.JSON(http.StatusOK, chatResp)
+ return usage, nil
+}
+
+func geminiResponseToChatCompletions(
+ geminiResp map[string]any,
+ originalModel string,
+ rawData []byte,
+ usageOverride *ClaudeUsage,
+) (*apicompat.ChatCompletionsResponse, *ClaudeUsage, error) {
+ claudeRespMap, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, rawData)
+ if usageOverride != nil && (usageOverride.InputTokens > 0 || usageOverride.OutputTokens > 0 || usageOverride.CacheReadInputTokens > 0) {
+ usage = usageOverride
+ if usageMap, ok := claudeRespMap["usage"].(map[string]any); ok {
+ usageMap["input_tokens"] = usage.InputTokens
+ usageMap["output_tokens"] = usage.OutputTokens
+ usageMap["cache_read_input_tokens"] = usage.CacheReadInputTokens
+ }
+ }
+
+ claudeBytes, err := json.Marshal(claudeRespMap)
+ if err != nil {
+ return nil, nil, err
+ }
+ var anthropicResp apicompat.AnthropicResponse
+ if err := json.Unmarshal(claudeBytes, &anthropicResp); err != nil {
+ return nil, nil, err
+ }
+ responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp)
+ return apicompat.ResponsesToChatCompletions(responsesResp, originalModel), usage, nil
+}
+
+func (s *GeminiMessagesCompatService) handleChatCompletionsStreamingResponseFromGemini(
+ c *gin.Context,
+ resp *http.Response,
+ startTime time.Time,
+ originalModel string,
+ isOAuth bool,
+ includeUsage bool,
+) (*geminiStreamResult, error) {
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ }
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.WriteHeader(http.StatusOK)
+
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ return nil, errors.New("streaming not supported")
+ }
+
+ anthState := apicompat.NewAnthropicEventToResponsesState()
+ anthState.Model = originalModel
+ ccState := apicompat.NewResponsesEventToChatState()
+ ccState.Model = originalModel
+ ccState.IncludeUsage = includeUsage
+
+ var usage ClaudeUsage
+ var firstTokenMs *int
+ firstChunk := true
+
+ writeChatChunk := func(chunk apicompat.ChatCompletionsChunk) bool {
+ sse, err := apicompat.ChatChunkToSSE(chunk)
+ if err != nil {
+ return false
+ }
+ if _, err := io.WriteString(c.Writer, sse); err != nil {
+ return true
+ }
+ return false
+ }
+
+ emitAnthropicEvent := func(evt *apicompat.AnthropicStreamEvent) bool {
+ responsesEvents := apicompat.AnthropicEventToResponsesEvents(evt, anthState)
+ for _, resEvt := range responsesEvents {
+ chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
+ for _, chunk := range chunks {
+ if disconnected := writeChatChunk(chunk); disconnected {
+ return true
+ }
+ }
+ }
+ flusher.Flush()
+ return false
+ }
+
+ messageID := "msg_" + randomHex(12)
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "message_start",
+ Message: &apicompat.AnthropicResponse{
+ ID: messageID,
+ Type: "message",
+ Role: "assistant",
+ Model: originalModel,
+ Content: []apicompat.AnthropicContentBlock{},
+ Usage: apicompat.AnthropicUsage{},
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+
+ finishReason := ""
+ sawToolUse := false
+ nextBlockIndex := 0
+ openBlockIndex := -1
+ openBlockType := ""
+ seenText := ""
+ openToolIndex := -1
+ openToolName := ""
+ seenToolJSON := ""
+
+ closeOpenBlock := func() bool {
+ if openBlockIndex < 0 {
+ return false
+ }
+ disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"})
+ openBlockIndex = -1
+ openBlockType = ""
+ return disconnected
+ }
+ closeOpenTool := func() bool {
+ if openToolIndex < 0 {
+ return false
+ }
+ disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"})
+ openToolIndex = -1
+ openToolName = ""
+ seenToolJSON = ""
+ return disconnected
+ }
+
+ reader := bufio.NewReader(resp.Body)
+ for {
+ line, err := reader.ReadString('\n')
+ if len(line) > 0 {
+ trimmed := strings.TrimRight(line, "\r\n")
+ if strings.HasPrefix(trimmed, "data:") {
+ payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:"))
+ if payload != "" && payload != "[DONE]" {
+ rawBytes := []byte(payload)
+ if isOAuth {
+ if innerBytes, uwErr := unwrapGeminiResponse(rawBytes); uwErr == nil {
+ rawBytes = innerBytes
+ }
+ }
+
+ var geminiResp map[string]any
+ if err := json.Unmarshal(rawBytes, &geminiResp); err == nil {
+ if firstChunk {
+ firstChunk = false
+ ms := int(time.Since(startTime).Milliseconds())
+ firstTokenMs = &ms
+ }
+ if fr := extractGeminiFinishReason(geminiResp); fr != "" {
+ finishReason = fr
+ }
+ if u := extractGeminiUsage(rawBytes); u != nil {
+ usage = *u
+ }
+
+ for _, part := range extractGeminiParts(geminiResp) {
+ if text, ok := part["text"].(string); ok && text != "" {
+ if openToolIndex >= 0 {
+ if closeOpenTool() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+ delta, newSeen := computeGeminiTextDelta(seenText, text)
+ seenText = newSeen
+ if delta == "" {
+ continue
+ }
+ if openBlockType != "text" {
+ if closeOpenBlock() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ idx := nextBlockIndex
+ nextBlockIndex++
+ openBlockIndex = idx
+ openBlockType = "text"
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "content_block_start",
+ Index: &idx,
+ ContentBlock: &apicompat.AnthropicContentBlock{
+ Type: "text",
+ Text: "",
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "content_block_delta",
+ Delta: &apicompat.AnthropicDelta{
+ Type: "text_delta",
+ Text: delta,
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ continue
+ }
+
+ if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil {
+ name, _ := fc["name"].(string)
+ if strings.TrimSpace(name) == "" {
+ name = "tool"
+ }
+ if closeOpenBlock() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if openToolIndex >= 0 && openToolName != name {
+ if closeOpenTool() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+ if openToolIndex < 0 {
+ idx := nextBlockIndex
+ nextBlockIndex++
+ openToolIndex = idx
+ openToolName = name
+ sawToolUse = true
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "content_block_start",
+ Index: &idx,
+ ContentBlock: &apicompat.AnthropicContentBlock{
+ Type: "tool_use",
+ ID: "toolu_" + randomHex(8),
+ Name: name,
+ Input: json.RawMessage(`{}`),
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+
+ argsJSONText := "{}"
+ switch v := fc["args"].(type) {
+ case nil:
+ case string:
+ if strings.TrimSpace(v) != "" {
+ argsJSONText = v
+ }
+ default:
+ if b, err := json.Marshal(v); err == nil && len(b) > 0 {
+ argsJSONText = string(b)
+ }
+ }
+ delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText)
+ seenToolJSON = newSeen
+ if delta != "" {
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "content_block_delta",
+ Delta: &apicompat.AnthropicDelta{
+ Type: "input_json_delta",
+ PartialJSON: delta,
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ if err != nil {
+ return nil, fmt.Errorf("stream read error: %w", err)
+ }
+ }
+
+ if closeOpenBlock() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if closeOpenTool() {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+
+ stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason)
+ if sawToolUse {
+ stopReason = "tool_use"
+ }
+ anthState.InputTokens = usage.InputTokens
+ anthState.CacheReadInputTokens = usage.CacheReadInputTokens
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{
+ Type: "message_delta",
+ Delta: &apicompat.AnthropicDelta{
+ Type: "message_delta",
+ StopReason: stopReason,
+ },
+ Usage: &apicompat.AnthropicUsage{
+ InputTokens: usage.InputTokens,
+ OutputTokens: usage.OutputTokens,
+ CacheReadInputTokens: usage.CacheReadInputTokens,
+ },
+ }) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "message_stop"}) {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+
+ for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) {
+ chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState)
+ for _, chunk := range chunks {
+ if disconnected := writeChatChunk(chunk); disconnected {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+ }
+ for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) {
+ if disconnected := writeChatChunk(chunk); disconnected {
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+ }
+ }
+
+ _, _ = io.WriteString(c.Writer, "data: [DONE]\n\n")
+ flusher.Flush()
+
+ return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil
+}
+
+func (s *GeminiMessagesCompatService) writeGeminiChatCompletionsMappedError(
+ c *gin.Context,
+ account *Account,
+ upstreamStatus int,
+ upstreamRequestID string,
+ body []byte,
+) error {
+ upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(body)))
+ setOpsUpstreamError(c, upstreamStatus, upstreamMsg, "")
+ if account != nil {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: upstreamStatus,
+ UpstreamRequestID: upstreamRequestID,
+ Kind: "http_error",
+ Message: upstreamMsg,
+ })
+ }
+
+ if status, errType, errMsg, matched := applyErrorPassthroughRule(
+ c,
+ PlatformGemini,
+ upstreamStatus,
+ body,
+ http.StatusBadGateway,
+ "upstream_error",
+ "Upstream request failed",
+ ); matched {
+ return s.writeChatCompletionsError(c, status, errType, errMsg)
+ }
+
+ statusCode := http.StatusBadGateway
+ errType := "upstream_error"
+ errMsg := "Upstream request failed"
+ if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil {
+ if mapped.Type != "" {
+ errType = mapped.Type
+ }
+ if mapped.Message != "" {
+ errMsg = mapped.Message
+ }
+ if mapped.StatusCode > 0 {
+ statusCode = mapped.StatusCode
+ }
+ }
+
+ switch upstreamStatus {
+ case http.StatusBadRequest:
+ if statusCode == http.StatusBadGateway {
+ statusCode = http.StatusBadRequest
+ }
+ if errType == "upstream_error" {
+ errType = "invalid_request_error"
+ }
+ if errMsg == "Upstream request failed" {
+ errMsg = "Invalid request"
+ }
+ case http.StatusNotFound:
+ statusCode = http.StatusNotFound
+ if errType == "upstream_error" {
+ errType = "not_found_error"
+ }
+ if errMsg == "Upstream request failed" {
+ errMsg = "Resource not found"
+ }
+ case http.StatusTooManyRequests:
+ statusCode = http.StatusTooManyRequests
+ if errType == "upstream_error" {
+ errType = "rate_limit_error"
+ }
+ if errMsg == "Upstream request failed" {
+ errMsg = "Upstream rate limit exceeded, please retry later"
+ }
+ case 529:
+ statusCode = http.StatusServiceUnavailable
+ if errType == "upstream_error" {
+ errType = "overloaded_error"
+ }
+ if errMsg == "Upstream request failed" {
+ errMsg = "Upstream service overloaded, please retry later"
+ }
+ }
+
+ if upstreamMsg != "" && errMsg == "Upstream request failed" {
+ errMsg = upstreamMsg
+ }
+ return s.writeChatCompletionsError(c, statusCode, errType, errMsg)
+}
+
+func (s *GeminiMessagesCompatService) writeChatCompletionsError(c *gin.Context, status int, errType, message string) error {
+ c.JSON(status, gin.H{
+ "error": gin.H{
+ "type": errType,
+ "message": message,
+ },
+ })
+ return fmt.Errorf("%s", message)
+}
diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go
index 4bd1ced7..84f9a706 100644
--- a/backend/internal/service/gemini_error_policy_test.go
+++ b/backend/internal/service/gemini_error_policy_test.go
@@ -383,6 +383,37 @@ func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) {
// policy tests. Embeds mockAccountRepoForGemini and adds tracking.
// ---------------------------------------------------------------------------
+func TestHandleGeminiUpstreamError_GoogleOneCapacityExhaustedUsesTierCooldown(t *testing.T) {
+ repo := &rateLimit429AccountRepoStub{}
+ quotaSvc := NewGeminiQuotaService(&config.Config{}, nil)
+ rlSvc := NewRateLimitService(repo, nil, &config.Config{}, quotaSvc, nil)
+ svc := &GeminiMessagesCompatService{
+ accountRepo: repo,
+ rateLimitService: rlSvc,
+ }
+
+ account := &Account{
+ ID: 511,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "oauth_type": "google_one",
+ "tier_id": "google_ai_pro",
+ },
+ }
+ body := []byte(`{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","domain":"cloudcode-pa.googleapis.com","metadata":{"model":"gemini-3.1-pro-preview"},"reason":"MODEL_CAPACITY_EXHAUSTED"}],"message":"No capacity available for model gemini-3.1-pro-preview on the server","status":"RESOURCE_EXHAUSTED"}}`)
+
+ before := time.Now()
+ svc.handleGeminiUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, body)
+ after := time.Now()
+
+ require.Equal(t, 1, repo.rateLimitCalls)
+ require.Equal(t, int64(511), repo.lastRateLimitID)
+ require.WithinDuration(t, before.Add(5*time.Minute), repo.lastRateLimitReset, 2*time.Second)
+ require.True(t, repo.lastRateLimitReset.After(before))
+ require.True(t, repo.lastRateLimitReset.Before(after.Add(5*time.Minute).Add(2*time.Second)))
+}
+
type geminiErrorPolicyRepo struct {
mockAccountRepoForGemini
setErrorCalls int
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index ea0c0d7d..516556ca 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -766,12 +766,6 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
- // Capture upstream request body for ops retry of this attempt.
- if c != nil {
- // In this code path `body` is already the JSON sent to upstream.
- c.Set(OpsUpstreamRequestBodyKey, string(body))
- }
-
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
@@ -1072,21 +1066,23 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
// 图片生成计费
imageCount := 0
- imageSize := s.extractImageSize(body)
+ imageInputSize := s.extractImageInputSize(body)
+ imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
- RequestID: requestID,
- Usage: *usage,
- Model: originalModel,
- UpstreamModel: mappedModel,
- Stream: req.Stream,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- ImageCount: imageCount,
- ImageSize: imageSize,
+ RequestID: requestID,
+ Usage: *usage,
+ Model: originalModel,
+ UpstreamModel: mappedModel,
+ Stream: req.Stream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: imageSize,
+ ImageInputSize: imageInputSize,
}, nil
}
@@ -1291,12 +1287,6 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
- // Capture upstream request body for ops retry of this attempt.
- if c != nil {
- // In this code path `body` is already the JSON sent to upstream.
- c.Set(OpsUpstreamRequestBodyKey, string(body))
- }
-
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
@@ -1600,21 +1590,23 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
// 图片生成计费
imageCount := 0
- imageSize := s.extractImageSize(body)
+ imageInputSize := s.extractImageInputSize(body)
+ imageSize := normalizeOpenAIImageSizeTier(imageInputSize)
if isImageGenerationModel(originalModel) {
imageCount = 1
}
return &ForwardResult{
- RequestID: requestID,
- Usage: *usage,
- Model: originalModel,
- UpstreamModel: mappedModel,
- Stream: stream,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- ImageCount: imageCount,
- ImageSize: imageSize,
+ RequestID: requestID,
+ Usage: *usage,
+ Model: originalModel,
+ UpstreamModel: mappedModel,
+ Stream: stream,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: imageSize,
+ ImageInputSize: imageInputSize,
}, nil
}
@@ -2822,14 +2814,18 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if resetAt == nil {
// 根据账号类型使用不同的默认重置时间
var ra time.Time
- if isCodeAssist {
- // Code Assist: fallback cooldown by tier
+ if isCodeAssist || oauthType == "google_one" {
+ // Gemini CLI / Google One: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
- logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
+ if isCodeAssist {
+ logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
+ } else {
+ logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Google One OAuth, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
+ }
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
@@ -3430,8 +3426,7 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any {
return out
}
-// extractImageSize 从 Gemini 请求中提取 image_size 参数
-func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
+func (s *GeminiMessagesCompatService) extractImageInputSize(body []byte) string {
var req struct {
GenerationConfig *struct {
ImageConfig *struct {
@@ -3440,15 +3435,12 @@ func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string {
} `json:"generationConfig"`
}
if err := json.Unmarshal(body, &req); err != nil {
- return "2K"
+ return ""
}
if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil {
- size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize))
- if size == "1K" || size == "2K" || size == "4K" {
- return size
- }
+ return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)
}
- return "2K"
+ return ""
}
diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go
index c2adf45d..d0560344 100644
--- a/backend/internal/service/gemini_messages_compat_service_test.go
+++ b/backend/internal/service/gemini_messages_compat_service_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
@@ -41,6 +42,134 @@ func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL str
return s.Do(req, proxyURL, accountID, accountConcurrency)
}
+func TestGeminiForwardAsChatCompletions_OAuthRoutesToGeminiAndReturnsChatFormat(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstreamBody := `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello from gemini"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":3}}}` + "\n\n" +
+ "data: [DONE]\n\n"
+ httpStub := &geminiCompatHTTPUpstreamStub{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ },
+ }
+ svc := &GeminiMessagesCompatService{
+ tokenProvider: &GeminiTokenProvider{},
+ httpUpstream: httpStub,
+ cfg: &config.Config{},
+ }
+ account := &Account{
+ ID: 101,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "ya29.test-token",
+ "project_id": "project-1",
+ },
+ Concurrency: 1,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"hi"}]}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+
+ result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Equal(t, "gemini-2.5-flash", result.Model)
+ require.Equal(t, 7, result.Usage.InputTokens)
+ require.Equal(t, 3, result.Usage.OutputTokens)
+
+ require.NotNil(t, httpStub.lastReq)
+ require.Contains(t, httpStub.lastReq.URL.String(), "/v1internal:streamGenerateContent?alt=sse")
+ require.Equal(t, "Bearer ya29.test-token", httpStub.lastReq.Header.Get("Authorization"))
+ require.Empty(t, httpStub.lastReq.Header.Get("x-api-key"))
+ require.Empty(t, httpStub.lastReq.Header.Get("anthropic-version"))
+
+ var sent map[string]any
+ sentBody, err := io.ReadAll(httpStub.lastReq.Body)
+ require.NoError(t, err)
+ require.NoError(t, json.Unmarshal(sentBody, &sent))
+ require.Equal(t, "gemini-2.5-flash", sent["model"])
+ require.Equal(t, "project-1", sent["project"])
+ require.Contains(t, fmt.Sprint(sent["request"]), "hi")
+
+ var got map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
+ require.Equal(t, "chat.completion", got["object"])
+ require.Equal(t, "gemini-2.5-flash", got["model"])
+ choices, ok := got["choices"].([]any)
+ require.True(t, ok)
+ require.NotEmpty(t, choices)
+ choice, ok := choices[0].(map[string]any)
+ require.True(t, ok)
+ message, ok := choice["message"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "assistant", message["role"])
+ require.Equal(t, "hello from gemini", message["content"])
+ usage, ok := got["usage"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, float64(7), usage["prompt_tokens"])
+ require.Equal(t, float64(3), usage["completion_tokens"])
+ require.Equal(t, float64(10), usage["total_tokens"])
+}
+
+func TestGeminiForwardAsChatCompletions_StreamsOpenAIChunksFromGeminiSSE(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstreamBody := `data: {"candidates":[{"content":{"parts":[{"text":"hel"}]}}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":1}}` + "\n\n" +
+ `data: {"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":2}}` + "\n\n" +
+ "data: [DONE]\n\n"
+ httpStub := &geminiCompatHTTPUpstreamStub{
+ response: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ },
+ }
+ svc := &GeminiMessagesCompatService{
+ httpUpstream: httpStub,
+ cfg: &config.Config{},
+ }
+ account := &Account{
+ ID: 102,
+ Platform: PlatformGemini,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "gemini-api-key",
+ },
+ Concurrency: 1,
+ }
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gemini-2.5-flash","stream":true,"stream_options":{"include_usage":true},"messages":[{"role":"user","content":"hi"}]}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+
+ result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body)
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.True(t, result.Stream)
+ require.Equal(t, 2, result.Usage.InputTokens)
+ require.Equal(t, 2, result.Usage.OutputTokens)
+
+ require.NotNil(t, httpStub.lastReq)
+ require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse")
+ require.Equal(t, "gemini-api-key", httpStub.lastReq.Header.Get("x-goog-api-key"))
+
+ out := rec.Body.String()
+ require.Contains(t, out, `"object":"chat.completion.chunk"`)
+ require.Contains(t, out, `"role":"assistant"`)
+ require.Contains(t, out, `"content":"hel"`)
+ require.Contains(t, out, `"content":"lo"`)
+ require.Contains(t, out, `"usage":{"prompt_tokens":2,"completion_tokens":2,"total_tokens":4}`)
+ require.Contains(t, out, "data: [DONE]")
+}
+
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
tests := []struct {
diff --git a/backend/internal/service/image_billing_size.go b/backend/internal/service/image_billing_size.go
new file mode 100644
index 00000000..0ca69ac4
--- /dev/null
+++ b/backend/internal/service/image_billing_size.go
@@ -0,0 +1,260 @@
+package service
+
+import (
+ "sort"
+ "strconv"
+ "strings"
+)
+
+const (
+ ImageBillingSize1K = "1K"
+ ImageBillingSize2K = "2K"
+ ImageBillingSize4K = "4K"
+
+ ImageSizeSourceOutput = "output"
+ ImageSizeSourceInput = "input"
+ ImageSizeSourceDefault = "default"
+ ImageSizeSourceLegacy = "legacy"
+)
+
+type ImageBillingSizeResolution struct {
+ BillingSize string
+ InputSize string
+ OutputSize string
+ Source string
+ Breakdown map[string]int
+}
+
+func ClassifyImageBillingTier(size string) (string, bool) {
+ trimmed := strings.TrimSpace(size)
+ normalized := strings.ToLower(trimmed)
+ switch normalized {
+ case "", "auto":
+ return "", false
+ case "1k":
+ return ImageBillingSize1K, true
+ case "2k":
+ return ImageBillingSize2K, true
+ case "4k":
+ return ImageBillingSize4K, true
+ case "2048x2048", "2048x1152":
+ return ImageBillingSize2K, true
+ case "3840x2160", "2160x3840":
+ return ImageBillingSize4K, true
+ }
+
+ width, height, ok := parseImageBillingDimensions(trimmed)
+ if !ok {
+ return "", false
+ }
+ maxEdge := width
+ if height > maxEdge {
+ maxEdge = height
+ }
+ switch {
+ case maxEdge <= 1024:
+ return ImageBillingSize1K, true
+ case maxEdge <= 2048:
+ return ImageBillingSize2K, true
+ default:
+ return ImageBillingSize4K, true
+ }
+}
+
+func NormalizeImageBillingTierOrDefault(size string) string {
+ if tier, ok := ClassifyImageBillingTier(size); ok {
+ return tier
+ }
+ return ImageBillingSize2K
+}
+
+func ResolveImageBillingSize(inputSize string, outputSizes []string) ImageBillingSizeResolution {
+ inputSize = strings.TrimSpace(inputSize)
+ outputSizes = compactTrimmedStrings(outputSizes)
+
+ breakdown := map[string]int{}
+ outputSize := firstDisplayImageOutputSize(outputSizes)
+ outputTier := ""
+ for _, output := range outputSizes {
+ tier, ok := ClassifyImageBillingTier(output)
+ if !ok {
+ continue
+ }
+ breakdown[tier]++
+ if imageTierRank(tier) > imageTierRank(outputTier) {
+ outputTier = tier
+ }
+ }
+ if outputTier != "" {
+ return ImageBillingSizeResolution{
+ BillingSize: outputTier,
+ InputSize: inputSize,
+ OutputSize: outputSize,
+ Source: ImageSizeSourceOutput,
+ Breakdown: normalizeImageSizeBreakdown(breakdown),
+ }
+ }
+
+ if tier, ok := ClassifyImageBillingTier(inputSize); ok {
+ return ImageBillingSizeResolution{
+ BillingSize: tier,
+ InputSize: inputSize,
+ OutputSize: outputSize,
+ Source: ImageSizeSourceInput,
+ }
+ }
+
+ return ImageBillingSizeResolution{
+ BillingSize: ImageBillingSize2K,
+ InputSize: inputSize,
+ OutputSize: outputSize,
+ Source: ImageSizeSourceDefault,
+ }
+}
+
+func ApplyOpenAIImageBillingResolution(result *OpenAIForwardResult) {
+ if result == nil || result.ImageCount <= 0 {
+ return
+ }
+ inputSize := strings.TrimSpace(result.ImageInputSize)
+ if inputSize == "" && strings.TrimSpace(result.ImageSize) != ImageBillingSize2K {
+ inputSize = strings.TrimSpace(result.ImageSize)
+ }
+ outputSizes := result.ImageOutputSizes
+ if len(outputSizes) == 0 && strings.TrimSpace(result.ImageOutputSize) != "" {
+ outputSizes = []string{result.ImageOutputSize}
+ }
+ resolved := ResolveImageBillingSize(inputSize, outputSizes)
+ applyImageBillingResolution(
+ &result.ImageSize,
+ &result.ImageInputSize,
+ &result.ImageOutputSize,
+ &result.ImageSizeSource,
+ &result.ImageSizeBreakdown,
+ resolved,
+ )
+}
+
+func ApplyForwardImageBillingResolution(result *ForwardResult) {
+ if result == nil || result.ImageCount <= 0 {
+ return
+ }
+ inputSize := strings.TrimSpace(result.ImageInputSize)
+ if inputSize == "" && strings.TrimSpace(result.ImageSize) != ImageBillingSize2K {
+ inputSize = strings.TrimSpace(result.ImageSize)
+ }
+ outputSizes := result.ImageOutputSizes
+ if len(outputSizes) == 0 && strings.TrimSpace(result.ImageOutputSize) != "" {
+ outputSizes = []string{result.ImageOutputSize}
+ }
+ resolved := ResolveImageBillingSize(inputSize, outputSizes)
+ applyImageBillingResolution(
+ &result.ImageSize,
+ &result.ImageInputSize,
+ &result.ImageOutputSize,
+ &result.ImageSizeSource,
+ &result.ImageSizeBreakdown,
+ resolved,
+ )
+}
+
+func applyImageBillingResolution(
+ billingSize *string,
+ inputSize *string,
+ outputSize *string,
+ source *string,
+ breakdown *map[string]int,
+ resolved ImageBillingSizeResolution,
+) {
+ *billingSize = resolved.BillingSize
+ *inputSize = resolved.InputSize
+ *outputSize = resolved.OutputSize
+ *source = resolved.Source
+ *breakdown = resolved.Breakdown
+}
+
+func parseImageBillingDimensions(size string) (int, int, bool) {
+ parts := strings.Split(strings.ToLower(strings.TrimSpace(size)), "x")
+ if len(parts) != 2 {
+ return 0, 0, false
+ }
+ width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
+ if err != nil {
+ return 0, 0, false
+ }
+ height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
+ if err != nil {
+ return 0, 0, false
+ }
+ if width <= 0 || height <= 0 {
+ return 0, 0, false
+ }
+ return width, height, true
+}
+
+func compactTrimmedStrings(values []string) []string {
+ if len(values) == 0 {
+ return nil
+ }
+ out := make([]string, 0, len(values))
+ for _, value := range values {
+ trimmed := strings.TrimSpace(value)
+ if trimmed != "" {
+ out = append(out, trimmed)
+ }
+ }
+ return out
+}
+
+func firstDisplayImageOutputSize(outputSizes []string) string {
+ for _, output := range outputSizes {
+ if trimmed := strings.TrimSpace(output); trimmed != "" {
+ return trimmed
+ }
+ }
+ return ""
+}
+
+func imageTierRank(tier string) int {
+ switch strings.ToUpper(strings.TrimSpace(tier)) {
+ case ImageBillingSize1K:
+ return 1
+ case ImageBillingSize2K:
+ return 2
+ case ImageBillingSize4K:
+ return 3
+ default:
+ return 0
+ }
+}
+
+func normalizeImageSizeBreakdown(in map[string]int) map[string]int {
+ if len(in) == 0 {
+ return nil
+ }
+ out := make(map[string]int, len(in))
+ for _, tier := range []string{ImageBillingSize1K, ImageBillingSize2K, ImageBillingSize4K} {
+ if count := in[tier]; count > 0 {
+ out[tier] = count
+ }
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
+func SortedImageBillingBreakdownKeys(breakdown map[string]int) []string {
+ keys := make([]string, 0, len(breakdown))
+ for key := range breakdown {
+ keys = append(keys, key)
+ }
+ sort.Slice(keys, func(i, j int) bool {
+ left, right := imageTierRank(keys[i]), imageTierRank(keys[j])
+ if left == right {
+ return keys[i] < keys[j]
+ }
+ return left < right
+ })
+ return keys
+}
diff --git a/backend/internal/service/image_billing_size_test.go b/backend/internal/service/image_billing_size_test.go
new file mode 100644
index 00000000..48c9ac34
--- /dev/null
+++ b/backend/internal/service/image_billing_size_test.go
@@ -0,0 +1,110 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestClassifyImageBillingTier(t *testing.T) {
+ tests := []struct {
+ name string
+ size string
+ wantTier string
+ wantOK bool
+ }{
+ {name: "explicit 2k square", size: "2048x2048", wantTier: "2K", wantOK: true},
+ {name: "explicit 2k landscape", size: "2048x1152", wantTier: "2K", wantOK: true},
+ {name: "explicit 4k landscape", size: "3840x2160", wantTier: "4K", wantOK: true},
+ {name: "explicit 4k portrait", size: "2160x3840", wantTier: "4K", wantOK: true},
+ {name: "long edge 1k", size: "1024X768", wantTier: "1K", wantOK: true},
+ {name: "long edge 2k", size: "1280x768", wantTier: "2K", wantOK: true},
+ {name: "long edge 4k", size: "2560x1600", wantTier: "4K", wantOK: true},
+ {name: "tier string 1k", size: "1k", wantTier: "1K", wantOK: true},
+ {name: "empty", size: "", wantOK: false},
+ {name: "auto", size: "auto", wantOK: false},
+ {name: "invalid", size: "not-a-size", wantOK: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ gotTier, gotOK := ClassifyImageBillingTier(tt.size)
+ require.Equal(t, tt.wantOK, gotOK)
+ require.Equal(t, tt.wantTier, gotTier)
+ })
+ }
+}
+
+func TestResolveImageBillingSize(t *testing.T) {
+ tests := []struct {
+ name string
+ inputSize string
+ outputSizes []string
+ wantBilling string
+ wantOutput string
+ wantSource string
+ wantBreakdown map[string]int
+ }{
+ {
+ name: "output wins over input",
+ inputSize: "1024x1024",
+ outputSizes: []string{"3840x2160"},
+ wantBilling: "4K",
+ wantOutput: "3840x2160",
+ wantSource: ImageSizeSourceOutput,
+ wantBreakdown: map[string]int{"4K": 1},
+ },
+ {
+ name: "input fallback",
+ inputSize: "1024x1024",
+ wantBilling: "1K",
+ wantSource: ImageSizeSourceInput,
+ },
+ {
+ name: "auto defaults",
+ inputSize: "auto",
+ wantBilling: "2K",
+ wantSource: ImageSizeSourceDefault,
+ },
+ {
+ name: "empty defaults",
+ inputSize: "",
+ wantBilling: "2K",
+ wantSource: ImageSizeSourceDefault,
+ },
+ {
+ name: "invalid defaults",
+ inputSize: "largest",
+ wantBilling: "2K",
+ wantSource: ImageSizeSourceDefault,
+ },
+ {
+ name: "mixed output chooses highest tier",
+ inputSize: "1024x1024",
+ outputSizes: []string{"1024x1024", "3840x2160", "1280x720"},
+ wantBilling: "4K",
+ wantOutput: "1024x1024",
+ wantSource: ImageSizeSourceOutput,
+ wantBreakdown: map[string]int{"1K": 1, "2K": 1, "4K": 1},
+ },
+ {
+ name: "unparseable output falls back to parseable input",
+ inputSize: "2048x1152",
+ outputSizes: []string{"auto"},
+ wantBilling: "2K",
+ wantOutput: "auto",
+ wantSource: ImageSizeSourceInput,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ResolveImageBillingSize(tt.inputSize, tt.outputSizes)
+ require.Equal(t, tt.wantBilling, got.BillingSize)
+ require.Equal(t, tt.inputSize, got.InputSize)
+ require.Equal(t, tt.wantOutput, got.OutputSize)
+ require.Equal(t, tt.wantSource, got.Source)
+ require.Equal(t, tt.wantBreakdown, got.Breakdown)
+ })
+ }
+}
diff --git a/backend/internal/service/image_generation_intent.go b/backend/internal/service/image_generation_intent.go
index b6ef1065..4aca1239 100644
--- a/backend/internal/service/image_generation_intent.go
+++ b/backend/internal/service/image_generation_intent.go
@@ -170,7 +170,13 @@ func cloneRequestMapForImageIntent(body []byte) map[string]any {
return out
}
-func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) {
+type OpenAIResponsesImageBillingConfig struct {
+ Model string
+ SizeTier string
+ InputSize string
+}
+
+func resolveOpenAIResponsesImageBillingConfigDetailed(reqBody map[string]any, fallbackModel string) (OpenAIResponsesImageBillingConfig, error) {
imageModel := ""
imageSize := ""
hasImageTool := false
@@ -203,12 +209,24 @@ func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackMo
imageModel = strings.TrimSpace(fallbackModel)
}
sizeTier := normalizeOpenAIImageSizeTier(imageSize)
- return imageModel, sizeTier, nil
+ return OpenAIResponsesImageBillingConfig{
+ Model: imageModel,
+ SizeTier: sizeTier,
+ InputSize: imageSize,
+ }, nil
}
func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) {
+ cfg, err := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body, fallbackModel)
+ if err != nil {
+ return "", "", err
+ }
+ return cfg.Model, cfg.SizeTier, nil
+}
+
+func resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body []byte, fallbackModel string) (OpenAIResponsesImageBillingConfig, error) {
reqBody := cloneRequestMapForImageIntent(body)
- return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel)
+ return resolveOpenAIResponsesImageBillingConfigDetailed(reqBody, fallbackModel)
}
func isOpenAIImageBillingModelAlias(model string) bool {
diff --git a/backend/internal/service/image_generation_intent_test.go b/backend/internal/service/image_generation_intent_test.go
index 5e7bec79..4621e9d9 100644
--- a/backend/internal/service/image_generation_intent_test.go
+++ b/backend/internal/service/image_generation_intent_test.go
@@ -140,9 +140,10 @@ func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *te
func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) {
counter := newOpenAIImageOutputCounter()
counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`))
- counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`))
- counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`))
+ counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a","size":"1024x1024"}}`))
+ counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b","size":"3840x2160"}]}}`))
require.Equal(t, 2, counter.Count())
+ require.Equal(t, []string{"1024x1024", "3840x2160"}, counter.Sizes())
}
func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) {
@@ -182,3 +183,36 @@ func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing.
)
require.Equal(t, 2, counter.Count())
}
+
+func TestCollectOpenAIResponseImageOutputSizesFromJSONBytes(t *testing.T) {
+ body := []byte(`{
+ "output": [
+ {"id":"ig_1","type":"image_generation_call","result":"final-a","size":"3840x2160"},
+ {"id":"ig_2","type":"image_generation_call","result":"final-b","size":"1024x1024"}
+ ]
+ }`)
+
+ require.Equal(t, 2, countOpenAIResponseImageOutputsFromJSONBytes(body))
+ require.Equal(t, []string{"3840x2160", "1024x1024"}, collectOpenAIResponseImageOutputSizesFromJSONBytes(body))
+}
+
+func TestCollectOpenAIResponseImageOutputSizesFromImagesAPIData(t *testing.T) {
+ body := []byte(`{
+ "data": [
+ {"b64_json":"final-a","size":"2048x1152"},
+ {"b64_json":"final-b","size":"2048x1152"}
+ ]
+ }`)
+
+ require.Equal(t, 2, countOpenAIResponseImageOutputsFromJSONBytes(body))
+ require.Equal(t, []string{"2048x1152", "2048x1152"}, collectOpenAIResponseImageOutputSizesFromJSONBytes(body))
+}
+
+func TestCollectOpenAIImageOutputSizesFromSSEBody(t *testing.T) {
+ body := "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_1\",\"type\":\"image_generation_call\",\"result\":\"final-a\",\"size\":\"3840x2160\"}}\n\n" +
+ "data: {\"type\":\"response.completed\",\"response\":{\"output\":[{\"id\":\"ig_1\",\"type\":\"image_generation_call\",\"result\":\"final-a\"},{\"id\":\"ig_2\",\"type\":\"image_generation_call\",\"result\":\"final-b\",\"size\":\"1024x1024\"}]}}\n\n" +
+ "data: [DONE]\n\n"
+
+ require.Equal(t, 2, countOpenAIImageOutputsFromSSEBody(body))
+ require.Equal(t, []string{"3840x2160", "1024x1024"}, collectOpenAIImageOutputSizesFromSSEBody(body))
+}
diff --git a/backend/internal/service/image_output_accounting.go b/backend/internal/service/image_output_accounting.go
index 219c0c59..2f2bd6ae 100644
--- a/backend/internal/service/image_output_accounting.go
+++ b/backend/internal/service/image_output_accounting.go
@@ -10,12 +10,18 @@ import (
type openAIImageOutputCounter struct {
seen map[string]struct{}
+ seenSizes map[string]string
+ seenOrder []string
+ dataSizes []string
count int
maxDataCount int
}
func newOpenAIImageOutputCounter() *openAIImageOutputCounter {
- return &openAIImageOutputCounter{seen: make(map[string]struct{})}
+ return &openAIImageOutputCounter{
+ seen: make(map[string]struct{}),
+ seenSizes: make(map[string]string),
+ }
}
func (c *openAIImageOutputCounter) Count() int {
@@ -28,6 +34,25 @@ func (c *openAIImageOutputCounter) Count() int {
return c.count
}
+func (c *openAIImageOutputCounter) Sizes() []string {
+ if c == nil {
+ return nil
+ }
+ sizes := make([]string, 0, len(c.seenOrder)+len(c.dataSizes))
+ for _, key := range c.seenOrder {
+ if size := strings.TrimSpace(c.seenSizes[key]); size != "" {
+ sizes = append(sizes, size)
+ }
+ }
+ if len(sizes) == 0 && len(c.dataSizes) > 0 {
+ sizes = append(sizes, c.dataSizes...)
+ }
+ if len(sizes) == 0 {
+ return nil
+ }
+ return sizes
+}
+
func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) {
if c == nil || len(body) == 0 || !gjson.ValidBytes(body) {
return
@@ -73,10 +98,20 @@ func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) {
if !data.IsArray() {
return
}
- count := len(data.Array())
+ items := data.Array()
+ count := len(items)
if count > c.maxDataCount {
c.maxDataCount = count
}
+ sizes := make([]string, 0, len(items))
+ for _, item := range items {
+ if size := strings.TrimSpace(item.Get("size").String()); size != "" {
+ sizes = append(sizes, size)
+ }
+ }
+ if len(sizes) > 0 {
+ c.dataSizes = sizes
+ }
}
func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) {
@@ -120,10 +155,18 @@ func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) {
if key == "" {
return
}
+ size := strings.TrimSpace(item.Get("size").String())
if _, exists := c.seen[key]; exists {
+ if size != "" && strings.TrimSpace(c.seenSizes[key]) == "" {
+ c.seenSizes[key] = size
+ }
return
}
c.seen[key] = struct{}{}
+ c.seenOrder = append(c.seenOrder, key)
+ if size != "" {
+ c.seenSizes[key] = size
+ }
c.count++
}
@@ -142,8 +185,20 @@ func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int {
return counter.Count()
}
+func collectOpenAIResponseImageOutputSizesFromJSONBytes(body []byte) []string {
+ counter := newOpenAIImageOutputCounter()
+ counter.AddJSONResponse(body)
+ return counter.Sizes()
+}
+
func countOpenAIImageOutputsFromSSEBody(body string) int {
counter := newOpenAIImageOutputCounter()
counter.AddSSEBody(body)
return counter.Count()
}
+
+func collectOpenAIImageOutputSizesFromSSEBody(body string) []string {
+ counter := newOpenAIImageOutputCounter()
+ counter.AddSSEBody(body)
+ return counter.Sizes()
+}
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index a3b69dee..1c2e3cb3 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -8,6 +8,7 @@ import (
var codexModelMap = map[string]string{
"gpt-5.5": "gpt-5.5",
+ "codex-auto-review": "codex-auto-review",
"gpt-5.4": "gpt-5.4",
"gpt-5.4-mini": "gpt-5.4-mini",
"gpt-5.4-none": "gpt-5.4",
@@ -1030,7 +1031,7 @@ func filterCodexInputWithOptions(input []any, opts codexInputFilterOptions) []an
return id
}
if strings.HasPrefix(id, "call_") {
- return "fc" + strings.TrimPrefix(id, "call_")
+ return "fc_" + strings.TrimPrefix(id, "call_")
}
return "fc_" + id
}
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 9c72760a..4c182b8e 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -41,7 +41,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) {
second, ok := input[1].(map[string]any)
require.True(t, ok)
require.Equal(t, "o1", second["id"])
- require.Equal(t, "fc1", second["call_id"])
+ require.Equal(t, "fc_1", second["call_id"])
}
func TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly(t *testing.T) {
@@ -120,11 +120,11 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly
first, ok := input[0].(map[string]any)
require.True(t, ok)
- require.Equal(t, "fc1", first["id"])
+ require.Equal(t, "fc_1", first["id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
- require.Equal(t, "fc1", second["call_id"])
+ require.Equal(t, "fc_1", second["call_id"])
}
func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) {
@@ -144,7 +144,7 @@ func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T)
first, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "tool_search_output", first["type"])
- require.Equal(t, "fc1", first["call_id"])
+ require.Equal(t, "fc_1", first["call_id"])
}
func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) {
@@ -164,11 +164,11 @@ func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testi
first, ok := input[0].(map[string]any)
require.True(t, ok)
- require.Equal(t, "fccustom", first["call_id"])
+ require.Equal(t, "fc_custom", first["call_id"])
second, ok := input[1].(map[string]any)
require.True(t, ok)
- require.Equal(t, "fcmcp", second["call_id"])
+ require.Equal(t, "fc_mcp", second["call_id"])
}
func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) {
@@ -221,7 +221,7 @@ func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "function_call_output", item["type"])
- require.Equal(t, "fc1", item["call_id"])
+ require.Equal(t, "fc_1", item["call_id"])
require.Equal(t, "ok", item["output"])
_, hasRole := item["role"]
require.False(t, hasRole)
@@ -340,7 +340,7 @@ func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testin
require.True(t, ok)
require.Equal(t, "function_call", item["type"])
require.Equal(t, "tool", item["name"])
- require.Equal(t, "fc1", item["call_id"])
+ require.Equal(t, "fc_1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
@@ -359,7 +359,7 @@ func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) {
item, ok := input[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "shell", item["name"])
- require.Equal(t, "fc1", item["call_id"])
+ require.Equal(t, "fc_1", item["call_id"])
}
func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
@@ -384,7 +384,7 @@ func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) {
require.True(t, ok)
require.Equal(t, "mcp_tool_call", item["type"])
require.Equal(t, "remote_tool", item["name"])
- require.Equal(t, "fcabc", item["call_id"])
+ require.Equal(t, "fc_abc", item["call_id"])
}
func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) {
@@ -839,6 +839,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) {
"gpt-5.4": "gpt-5.4",
"gpt5.5": "gpt-5.5",
"openai/gpt5.5": "gpt-5.5",
+ "codex-auto-review": "codex-auto-review",
"gpt5.4": "gpt-5.4",
"gpt-5.4-high": "gpt-5.4",
"gpt-5.4-chat-latest": "gpt-5.4",
diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go
index e222b093..f8b9d360 100644
--- a/backend/internal/service/openai_compat_model_test.go
+++ b/backend/internal/service/openai_compat_model_test.go
@@ -183,6 +183,63 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T
t.Logf("response body: %s", rec.Body.String())
}
+func TestForwardAsAnthropic_MappedClaudeModelAcceptsChatUsageShape(t *testing.T) {
+ t.Parallel()
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"claude-opus-4-7","max_tokens":16,"messages":[{"role":"user","content":"compact this"}],"stream":true}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `data: {"type":"response.created","response":{"id":"resp_compact","model":"gpt-5.5","status":"in_progress","output":[]}}`,
+ "",
+ `data: {"type":"response.output_text.delta","delta":"ok"}`,
+ "",
+ `data: {"type":"response.completed","response":{"id":"resp_compact","object":"response","model":"gpt-5.5","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"prompt_tokens":31,"completion_tokens":9,"total_tokens":40,"prompt_tokens_details":{"cached_tokens":11}}}}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compact_usage"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ httpUpstream: upstream,
+ cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
+ }
+ account := &Account{
+ ID: 1,
+ Name: "openai-apikey",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "base_url": "https://api.openai.com/v1",
+ "model_mapping": map[string]any{
+ "gpt-5.5": "gpt-5.5",
+ },
+ },
+ }
+
+ result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.5")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "claude-opus-4-7", result.Model)
+ require.Equal(t, "gpt-5.5", result.BillingModel)
+ require.Equal(t, "gpt-5.5", result.UpstreamModel)
+ require.Equal(t, 31, result.Usage.InputTokens)
+ require.Equal(t, 9, result.Usage.OutputTokens)
+ require.Equal(t, 11, result.Usage.CacheReadInputTokens)
+ require.Equal(t, "gpt-5.5", gjson.GetBytes(upstream.lastBody, "model").String())
+}
+
func TestForwardAsAnthropic_InjectsPromptCacheKeyForAPIKeyMessagesDispatch(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
@@ -1360,6 +1417,135 @@ func TestForwardAsAnthropic_TerminalUsageWithoutUpstreamCloseReturns(t *testing.
}
}
+func TestForwardAsAnthropic_EventNamedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
+ body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := []byte(strings.Join([]string{
+ `event: response.completed`,
+ `data: {"response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}`,
+ ``,
+ ``,
+ }, "\n"))
+ upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
+ defer func() {
+ require.NoError(t, upstreamStream.Close())
+ }()
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_messages_event_named_terminal"}},
+ Body: upstreamStream,
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+
+ type forwardResult struct {
+ result *OpenAIForwardResult
+ err error
+ }
+ resultCh := make(chan forwardResult, 1)
+ go func() {
+ result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
+ resultCh <- forwardResult{result: result, err: err}
+ }()
+
+ select {
+ case got := <-resultCh:
+ require.NoError(t, got.err)
+ require.NotNil(t, got.result)
+ require.Equal(t, 15, got.result.Usage.InputTokens)
+ require.Equal(t, 6, got.result.Usage.OutputTokens)
+ require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
+ case <-time.After(time.Second):
+ require.Fail(t, "ForwardAsAnthropic should use SSE event names when data payloads omit type")
+ }
+}
+
+func TestForwardAsAnthropic_EventNamedTerminalWithKeepaliveReturns(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Writer = &openAICompatFailingWriter{ResponseWriter: c.Writer, failAfter: 0}
+ body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":true}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := []byte(strings.Join([]string{
+ `: upstream ping`,
+ ``,
+ `event: response.completed`,
+ `data: {"response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}`,
+ ``,
+ ``,
+ }, "\n"))
+ upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
+ defer func() {
+ require.NoError(t, upstreamStream.Close())
+ }()
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_messages_event_named_keepalive"}},
+ Body: upstreamStream,
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{Gateway: config.GatewayConfig{
+ StreamKeepaliveInterval: 5,
+ }},
+ httpUpstream: upstream,
+ }
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+
+ type forwardResult struct {
+ result *OpenAIForwardResult
+ err error
+ }
+ resultCh := make(chan forwardResult, 1)
+ go func() {
+ result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
+ resultCh <- forwardResult{result: result, err: err}
+ }()
+
+ select {
+ case got := <-resultCh:
+ require.NoError(t, got.err)
+ require.NotNil(t, got.result)
+ require.Equal(t, 15, got.result.Usage.InputTokens)
+ require.Equal(t, 6, got.result.Usage.OutputTokens)
+ require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
+ case <-time.After(time.Second):
+ require.Fail(t, "ForwardAsAnthropic keepalive path should use SSE event names when data payloads omit type")
+ }
+}
+
func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -1416,6 +1602,67 @@ func TestForwardAsAnthropic_BufferedTerminalWithoutUpstreamCloseReturns(t *testi
}
}
+func TestForwardAsAnthropic_BufferedEventNamedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gpt-5.4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"stream":false}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := []byte(strings.Join([]string{
+ `event: response.completed`,
+ `data: {"response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":15,"output_tokens":6,"total_tokens":21,"input_tokens_details":{"cached_tokens":5}}}}`,
+ ``,
+ ``,
+ }, "\n"))
+ upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
+ defer func() {
+ require.NoError(t, upstreamStream.Close())
+ }()
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_messages_buffered_event_named"}},
+ Body: upstreamStream,
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+
+ type forwardResult struct {
+ result *OpenAIForwardResult
+ err error
+ }
+ resultCh := make(chan forwardResult, 1)
+ go func() {
+ result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.1")
+ resultCh <- forwardResult{result: result, err: err}
+ }()
+
+ select {
+ case got := <-resultCh:
+ require.NoError(t, got.err)
+ require.NotNil(t, got.result)
+ require.Equal(t, 15, got.result.Usage.InputTokens)
+ require.Equal(t, 6, got.result.Usage.OutputTokens)
+ require.Equal(t, 5, got.result.Usage.CacheReadInputTokens)
+ require.Contains(t, rec.Body.String(), `"stop_reason":"end_turn"`)
+ case <-time.After(time.Second):
+ require.Fail(t, "ForwardAsAnthropic buffered response should use SSE event names when data payloads omit type")
+ }
+}
+
func TestForwardAsAnthropic_DoneSentinelWithoutTerminalReturnsError(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/openai_endpoint_url.go b/backend/internal/service/openai_endpoint_url.go
new file mode 100644
index 00000000..93ae9b95
--- /dev/null
+++ b/backend/internal/service/openai_endpoint_url.go
@@ -0,0 +1,78 @@
+package service
+
+import (
+ "net/url"
+ "strings"
+)
+
+func buildOpenAIEndpointURL(base string, endpoint string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ endpoint = "/" + strings.TrimLeft(strings.TrimSpace(endpoint), "/")
+ relative := strings.TrimPrefix(endpoint, "/v1")
+ if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
+ return normalized
+ }
+ if openAIBaseURLHasVersionSuffix(normalized) {
+ return normalized + relative
+ }
+ return normalized + endpoint
+}
+
+func openAIBaseURLHasVersionSuffix(raw string) bool {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return false
+ }
+
+ pathValue := ""
+ if parsed, err := url.Parse(trimmed); err == nil && parsed.Scheme != "" && parsed.Host != "" {
+ pathValue = parsed.Path
+ } else if slash := strings.Index(trimmed, "/"); slash >= 0 {
+ pathValue = trimmed[slash:]
+ }
+
+ pathValue = strings.TrimRight(pathValue, "/")
+ if pathValue == "" {
+ return false
+ }
+ lastSlash := strings.LastIndex(pathValue, "/")
+ segment := pathValue
+ if lastSlash >= 0 {
+ segment = pathValue[lastSlash+1:]
+ }
+ return isOpenAIAPIVersionSegment(segment)
+}
+
+func isOpenAIAPIVersionSegment(segment string) bool {
+ s := strings.ToLower(strings.TrimSpace(segment))
+ if len(s) < 2 || s[0] != 'v' || !isASCIIDigit(s[1]) {
+ return false
+ }
+
+ i := 1
+ for i < len(s) && isASCIIDigit(s[i]) {
+ i++
+ }
+ if i == len(s) {
+ return true
+ }
+ if s[i] == '.' {
+ i++
+ if i == len(s) || !isASCIIDigit(s[i]) {
+ return false
+ }
+ for i < len(s) && isASCIIDigit(s[i]) {
+ i++
+ }
+ return i == len(s)
+ }
+
+ suffix := s[i:]
+ return strings.HasPrefix(suffix, "alpha") ||
+ strings.HasPrefix(suffix, "beta") ||
+ strings.HasPrefix(suffix, "preview")
+}
+
+func isASCIIDigit(b byte) bool {
+ return b >= '0' && b <= '9'
+}
diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go
index b52da614..70fcaffa 100644
--- a/backend/internal/service/openai_fast_policy_test.go
+++ b/backend/internal/service/openai_fast_policy_test.go
@@ -8,6 +8,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
type openAIFastPolicyRepoStub struct {
@@ -62,25 +63,33 @@ func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolic
}
}
-func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) {
+func openAIFastFilterPriorityPolicy() *OpenAIFastPolicySettings {
+ return &OpenAIFastPolicySettings{
+ Rules: []OpenAIFastPolicyRule{{
+ ServiceTier: OpenAIFastTierPriority,
+ Action: BetaPolicyActionFilter,
+ Scope: BetaPolicyScopeAll,
+ ModelWhitelist: []string{},
+ FallbackAction: BetaPolicyActionPass,
+ }},
+ }
+}
+
+func TestEvaluateOpenAIFastPolicy_DefaultPassesKnownTiers(t *testing.T) {
+ require.Empty(t, DefaultOpenAIFastPolicySettings().Rules, "default policy must not rewrite service_tier unless admin configured rules")
+
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- // 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast
- // 是用户级开关,与 model 正交。
- // gpt-5.5 + priority → filter
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority)
- require.Equal(t, BetaPolicyActionFilter, action)
+ require.Equal(t, BetaPolicyActionPass, action)
- // gpt-5.5-turbo → filter
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority)
- require.Equal(t, BetaPolicyActionFilter, action)
+ require.Equal(t, BetaPolicyActionPass, action)
- // gpt-4 + priority → filter(默认策略覆盖所有模型)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority)
- require.Equal(t, BetaPolicyActionFilter, action)
+ require.Equal(t, BetaPolicyActionPass, action)
- // gpt-5.5 + flex → pass (tier doesn't match)
action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex)
require.Equal(t, BetaPolicyActionPass, action)
@@ -129,27 +138,24 @@ func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) {
require.Equal(t, BetaPolicyActionPass, action)
}
-func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
+func TestApplyOpenAIFastPolicyToBody_DefaultPassesPriorityAndFast(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- // gpt-5.5 fast → service_tier stripped
body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
- require.NotContains(t, string(updated), `"service_tier"`)
+ require.Equal(t, string(body), string(updated))
- // Client sending "fast" (alias for priority) also filtered
body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err)
- require.NotContains(t, string(updated), `"service_tier"`)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())
- // gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除
body = []byte(`{"model":"gpt-4","service_tier":"priority"}`)
updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body)
require.NoError(t, err)
- require.NotContains(t, string(updated), `"service_tier"`)
+ require.Equal(t, string(body), string(updated))
// No service_tier → no-op
body = []byte(`{"model":"gpt-5.5"}`)
@@ -158,9 +164,23 @@ func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) {
require.Equal(t, string(body), string(updated))
}
-// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后
-// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被
-// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。
+func TestApplyOpenAIFastPolicyToBody_ExplicitFilterRemovesField(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy())
+ account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
+
+ body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`)
+ updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+
+ body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`)
+ updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
+ require.NoError(t, err)
+ require.NotContains(t, string(updated), `"service_tier"`)
+}
+
+// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证默认配置
+// 下客户端显式发送的 OpenAI 官方合法 tier 能透传到上游而不被静默剥离。
func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
@@ -170,10 +190,10 @@ func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T
updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body)
require.NoError(t, err, "tier %q should pass without error", tier)
require.Contains(t, string(updated), `"service_tier":"`+tier+`"`,
- "tier %q should be preserved in body under default rule", tier)
+ "tier %q should be preserved in body under default policy", tier)
}
- // evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配)
+ // evaluate 层也应判定为 pass(默认配置没有内置规则)。
for _, tier := range []string{"auto", "default", "scale"} {
action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier)
require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier)
diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go
index 7c8341b2..4624e7a5 100644
--- a/backend/internal/service/openai_fast_policy_ws_test.go
+++ b/backend/internal/service/openai_fast_policy_ws_test.go
@@ -22,7 +22,7 @@ import (
// --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate ---
-func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) {
+func TestWSResponseCreate_DefaultPassesPriorityAndNormalizesFast(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
@@ -30,26 +30,37 @@ func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) {
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
- require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), "default policy should preserve priority tier")
// Other fields preserved.
require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String())
require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String())
require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String())
+
+ frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), "fast alias should normalize before reaching upstream")
+
+ // Mixed-case + whitespace variant should also normalize.
+ frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
+ updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
+ require.NoError(t, err)
+ require.Nil(t, blocked)
+ require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String())
}
-func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) {
- svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+func TestWSResponseCreate_ExplicitFilterStripsServiceTier(t *testing.T) {
+ svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- // Verbatim "fast" → normalized to "priority" → matches default rule → filter.
- frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
+ frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
- require.NotContains(t, string(updated), `"service_tier"`)
+ require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier")
- // Mixed-case + whitespace variant should also normalize and filter.
- frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`)
+ frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`)
updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
require.Nil(t, blocked)
@@ -60,7 +71,7 @@ func TestWSResponseCreate_FlexPassThrough(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- // Default policy targets priority only; flex is left untouched.
+ // Default policy has no rules; flex is left untouched.
frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`)
updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame)
require.NoError(t, err)
@@ -220,8 +231,8 @@ func (f *fakePassthroughFrameConn) Close() error {
}
// gpt55WhitelistFastPolicy 返回一份强制带 model whitelist 的策略,用于
-// 验证 capturedSessionModel fallback 的语义(默认策略 whitelist 为空时
-// fallback 路径无法被观察到)。
+// 验证 capturedSessionModel fallback 的语义(默认配置没有规则,fallback
+// 路径无法被观察到)。
func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
Rules: []OpenAIFastPolicyRule{{
@@ -242,7 +253,7 @@ func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings {
// through to the upstream.
func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) {
// 此处特意使用带 whitelist 的策略,以便观察 capturedSessionModel
- // fallback 是否生效(默认策略 whitelist 为空,fallback 与否结果一致,
+ // fallback 是否生效(默认配置没有规则,fallback 与否结果一致,
// 不能用来覆盖此回归)。
svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
@@ -310,13 +321,13 @@ func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing
"sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing")
}
-// --- Ingress end-to-end test (filter path) ---
+// --- Ingress end-to-end test (explicit filter path) ---
// TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the
// real ProxyResponsesWebSocketFromClient ingress session pipeline against a
// captureConn upstream and asserts that a client frame with service_tier=fast
-// is normalized + filtered out before being written upstream. This is the
-// integration flavour of TestWSResponseCreate_FilterStripsServiceTier.
+// is normalized + filtered out by an explicit admin policy before being
+// written upstream.
func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -345,9 +356,9 @@ func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T)
pool.setClientDialerForTest(captureDialer)
repo := &openAIFastPolicyRepoStub{values: map[string]string{}}
- defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings())
+ filterPolicyJSON, err := json.Marshal(openAIFastFilterPriorityPolicy())
require.NoError(t, err)
- repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON)
+ repo.values[SettingKeyOpenAIFastPolicySettings] = string(filterPolicyJSON)
svc := &OpenAIGatewayService{
cfg: cfg,
@@ -631,13 +642,13 @@ func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) {
require.Equal(t, string(body), string(updated), "block must not mutate body")
}
-// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies
-// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode
-// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60)
-// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has
-// no service_tier. We exercise the same internal pipeline (Anthropic→Responses
-// + BetaFastMode + policy) without spinning up a real upstream HTTP server.
-func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) {
+// TestForwardAsAnthropicMessages_BetaFastModePassesOpenAIFastPolicyByDefault
+// verifies the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode →
+// BetaFastMode detection → ServiceTier="priority" injection
+// (openai_gateway_messages.go:60) → default OpenAI fast policy pass. We
+// exercise the same internal pipeline (Anthropic→Responses + BetaFastMode +
+// policy) without spinning up a real upstream HTTP server.
+func TestForwardAsAnthropicMessages_BetaFastModePassesOpenAIFastPolicyByDefault(t *testing.T) {
svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
@@ -663,8 +674,9 @@ func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *test
upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody)
require.NoError(t, policyErr)
- // Step 4: assert that policy filtered the field before the upstream HTTP request.
- require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier")
+ // Step 4: default policy must preserve the explicit fast/priority request.
+ require.Equal(t, "priority", gjson.GetBytes(upstreamBody, "service_tier").String(),
+ "default policy should pass service_tier=priority through to upstream")
}
// --- Fix1: passthrough capturedSessionModel must follow session.update ---
@@ -808,7 +820,7 @@ func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) {
// tier) instead of the user-requested "priority". This test pins the
// contract those two helpers must uphold for the adapter's billing path.
func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
- svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`)
@@ -821,7 +833,7 @@ func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) {
require.Equal(t, "priority", *pre,
"sanity: raw first frame carries priority that pre-fix billing would have reported")
- // Apply policy filter (default rule: gpt-5.5 + priority → filter).
+ // Apply explicit policy filter (gpt-5.5 + priority → filter).
filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw)
require.NoError(t, err)
require.Nil(t, blocked)
@@ -890,9 +902,9 @@ func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
// atomic.Pointer[string] on every successful response.create frame.
//
// This test pins the four legs of the semantic contract:
-// - turn 1: service_tier=priority hits the default whitelist filter, so
+// - turn 1: service_tier=priority hits the explicit filter rule, so
// after filter the upstream sees no tier → billing is nil.
-// - turn 2: service_tier=flex passes (default rule targets priority only),
+// - turn 2: service_tier=flex passes (the filter rule targets priority only),
// billing should now reflect "flex".
// - turn 3: response.create without any service_tier — the upstream will
// treat it as default; we choose to mirror that and overwrite billing
@@ -900,7 +912,7 @@ func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) {
// - non-response.create frame (response.cancel here) carrying a stray
// service_tier-shaped field must NOT clobber the billing pointer.
func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) {
- svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings())
+ svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy())
account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
// Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go
diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go
index 84d85c74..f8b23a28 100644
--- a/backend/internal/service/openai_gateway_chat_completions.go
+++ b/backend/internal/service/openai_gateway_chat_completions.go
@@ -48,10 +48,10 @@ var cursorResponsesUnsupportedFields = []string{
// 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂:
// 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。
//
-// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI):
-// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions
+// 当前路由策略(基于账号覆盖模式/探测标记,详见 openai_compat.ShouldUseResponsesAPI):
+// - APIKey 账号 + 强制或探测确认不支持 Responses → 走 forwardAsRawChatCompletions
// 直转上游 /v1/chat/completions,不做协议转换
-// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses
+// - 其他所有情况(OAuth、APIKey 强制/探测确认支持、未探测)→ 走原有 CC→Responses
// 转换路径(保留旧行为,存量未探测账号零兼容破坏)
func (s *OpenAIGatewayService) ForwardAsChatCompletions(
ctx context.Context,
@@ -61,8 +61,8 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
promptCacheKey string,
defaultMappedModel string,
) (*OpenAIForwardResult, error) {
- // 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。
- // 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
+ // 入口分流:APIKey 账号 + 强制或已探测确认上游不支持 Responses,走 CC 直转。
+ // 自动模式下标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。
if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) {
return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
}
@@ -247,6 +247,16 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+ if account.Type == AccountTypeAPIKey &&
+ openai_compat.ResolveResponsesSupport(account.Extra) == openai_compat.ResponsesSupportUnknown &&
+ !isResponsesEndpointSupportedByStatus(resp.StatusCode) {
+ logger.L().Info("openai chat_completions: /responses unsupported, falling back to raw chat completions",
+ zap.Int64("account_id", account.ID),
+ zap.Int("upstream_status", resp.StatusCode),
+ zap.String("upstream_message", upstreamMsg),
+ )
+ return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel)
+ }
if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) {
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
@@ -282,7 +292,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions(
var result *OpenAIForwardResult
var handleErr error
if clientStream {
- result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime)
+ result, handleErr = s.handleChatStreamingResponse(resp, c, account, originalModel, billingModel, upstreamModel, includeUsage, startTime, len(body))
} else {
result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime)
}
@@ -404,22 +414,31 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse(
func (s *OpenAIGatewayService) handleChatStreamingResponse(
resp *http.Response,
c *gin.Context,
+ account *Account,
originalModel string,
billingModel string,
upstreamModel string,
includeUsage bool,
startTime time.Time,
+ requestBodyLen int,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
- if s.responseHeaderFilter != nil {
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ headersWritten := false
+ writeStreamHeaders := func() {
+ if headersWritten {
+ return
+ }
+ headersWritten = true
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ }
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.WriteHeader(http.StatusOK)
}
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.WriteHeader(http.StatusOK)
state := apicompat.NewResponsesEventToChatState()
state.Model = originalModel
@@ -429,6 +448,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
var firstTokenMs *int
firstChunk := true
clientDisconnected := false
+ clientOutputStarted := false
+ pendingSSE := make([]string, 0, 4)
+ refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen)
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -479,6 +501,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
)
return false
}
+ refusalDetector.ObservePayload([]byte(payload))
// 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。
isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type)
@@ -489,6 +512,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
chunks := apicompat.ResponsesEventToChatChunks(&event, state)
if !clientDisconnected {
for _, chunk := range chunks {
+ refusalDetector.ObserveChatChunk(chunk)
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
logger.L().Warn("openai chat_completions stream: failed to marshal chunk",
@@ -497,6 +521,27 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
)
continue
}
+ if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() {
+ pendingSSE = append(pendingSSE, sse)
+ continue
+ }
+ if !clientOutputStarted {
+ writeStreamHeaders()
+ for _, pending := range pendingSSE {
+ if _, err := fmt.Fprint(c.Writer, pending); err != nil {
+ clientDisconnected = true
+ logger.L().Info("openai chat_completions stream: client disconnected while flushing pending chunks",
+ zap.String("request_id", requestID),
+ )
+ break
+ }
+ }
+ pendingSSE = pendingSSE[:0]
+ clientOutputStarted = !clientDisconnected
+ if clientDisconnected {
+ break
+ }
+ }
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing",
@@ -506,7 +551,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}
}
}
- if len(chunks) > 0 && !clientDisconnected {
+ if len(chunks) > 0 && !clientDisconnected && clientOutputStarted {
c.Writer.Flush()
}
return isTerminalEvent
@@ -515,10 +560,32 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
finalizeStream := func() (*OpenAIForwardResult, error) {
if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected {
for _, chunk := range finalChunks {
+ refusalDetector.ObserveChatChunk(chunk)
sse, err := apicompat.ChatChunkToSSE(chunk)
if err != nil {
continue
}
+ if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() {
+ pendingSSE = append(pendingSSE, sse)
+ continue
+ }
+ if !clientOutputStarted {
+ writeStreamHeaders()
+ for _, pending := range pendingSSE {
+ if _, err := fmt.Fprint(c.Writer, pending); err != nil {
+ clientDisconnected = true
+ logger.L().Info("openai chat_completions stream: client disconnected during pending final flush",
+ zap.String("request_id", requestID),
+ )
+ break
+ }
+ }
+ pendingSSE = pendingSSE[:0]
+ clientOutputStarted = !clientDisconnected
+ if clientDisconnected {
+ break
+ }
+ }
if _, err := fmt.Fprint(c.Writer, sse); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during final flush",
@@ -528,14 +595,35 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}
}
}
+ if !clientDisconnected && !clientOutputStarted {
+ if refusalDetector.IsSilentRefusal() {
+ return nil, newOpenAISilentRefusalFailoverError(c, account, requestID)
+ }
+ if len(pendingSSE) > 0 {
+ writeStreamHeaders()
+ for _, pending := range pendingSSE {
+ if _, err := fmt.Fprint(c.Writer, pending); err != nil {
+ clientDisconnected = true
+ logger.L().Info("openai chat_completions stream: client disconnected during final pending flush",
+ zap.String("request_id", requestID),
+ )
+ break
+ }
+ }
+ pendingSSE = pendingSSE[:0]
+ clientOutputStarted = !clientDisconnected
+ }
+ }
// Send [DONE] sentinel
if !clientDisconnected {
+ writeStreamHeaders()
if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil {
clientDisconnected = true
logger.L().Info("openai chat_completions stream: client disconnected during done flush",
zap.String("request_id", requestID),
)
}
+ clientOutputStarted = !clientDisconnected
}
if !clientDisconnected {
c.Writer.Flush()
@@ -554,6 +642,13 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
+ processFrame := func(frame openAICompatSSEFrame) bool {
+ payload := openAICompatPayloadWithEventType(frame.Data, frame.EventType)
+ if strings.TrimSpace(payload) == "[DONE]" {
+ return false
+ }
+ return processDataLine(payload)
+ }
// Determine keepalive interval
keepaliveInterval := time.Duration(0)
@@ -563,16 +658,17 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
// No keepalive: fast synchronous path
if streamInterval <= 0 && keepaliveInterval <= 0 {
+ var parser openAICompatSSEFrameParser
for scanner.Scan() {
line := scanner.Text()
- payload, ok := extractOpenAISSEDataLine(line)
+ frame, ok := parser.AddLine(line)
if !ok {
continue
}
- if strings.TrimSpace(payload) == "[DONE]" {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
return missingTerminalErr()
}
- if processDataLine(payload) {
+ if processFrame(frame) {
return finalizeStream()
}
}
@@ -580,6 +676,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
+ if frame, ok := parser.Finish(); ok {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
+ return missingTerminalErr()
+ }
+ if processFrame(frame) {
+ return finalizeStream()
+ }
+ }
return missingTerminalErr()
}
@@ -624,11 +728,20 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
+ var parser openAICompatSSEFrameParser
for {
select {
case ev, ok := <-events:
if !ok {
+ if frame, ok := parser.Finish(); ok {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
+ return missingTerminalErr()
+ }
+ if processFrame(frame) {
+ return finalizeStream()
+ }
+ }
return missingTerminalErr()
}
if ev.err != nil {
@@ -637,14 +750,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
}
lastDataAt = time.Now()
line := ev.line
- payload, ok := extractOpenAISSEDataLine(line)
+ frame, ok := parser.AddLine(line)
if !ok {
continue
}
- if strings.TrimSpace(payload) == "[DONE]" {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
return missingTerminalErr()
}
- if processDataLine(payload) {
+ if processFrame(frame) {
return finalizeStream()
}
@@ -667,10 +780,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse(
if clientDisconnected {
continue
}
+ if refusalDetector.Enabled() && !clientOutputStarted {
+ continue
+ }
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
// Send SSE comment as keepalive
+ writeStreamHeaders()
if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil {
logger.L().Info("openai chat_completions stream: client disconnected during keepalive",
zap.String("request_id", requestID),
diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go
index 3be765a2..c585290e 100644
--- a/backend/internal/service/openai_gateway_chat_completions_raw.go
+++ b/backend/internal/service/openai_gateway_chat_completions_raw.go
@@ -220,7 +220,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
// 8. Forward response
if clientStream {
- return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
+ return s.streamRawChatCompletions(c, resp, account, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime, len(body))
}
return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime)
}
@@ -234,23 +234,32 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions(
func (s *OpenAIGatewayService) streamRawChatCompletions(
c *gin.Context,
resp *http.Response,
+ account *Account,
originalModel string,
billingModel string,
upstreamModel string,
reasoningEffort *string,
serviceTier *string,
startTime time.Time,
+ requestBodyLen int,
) (*OpenAIForwardResult, error) {
requestID := resp.Header.Get("x-request-id")
- if s.responseHeaderFilter != nil {
- responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ headersWritten := false
+ writeStreamHeaders := func() {
+ if headersWritten {
+ return
+ }
+ headersWritten = true
+ if s.responseHeaderFilter != nil {
+ responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
+ }
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
+ c.Writer.Header().Set("Cache-Control", "no-cache")
+ c.Writer.Header().Set("Connection", "keep-alive")
+ c.Writer.Header().Set("X-Accel-Buffering", "no")
+ c.Writer.WriteHeader(http.StatusOK)
}
- c.Writer.Header().Set("Content-Type", "text/event-stream")
- c.Writer.Header().Set("Cache-Control", "no-cache")
- c.Writer.Header().Set("Connection", "keep-alive")
- c.Writer.Header().Set("X-Accel-Buffering", "no")
- c.Writer.WriteHeader(http.StatusOK)
scanner := bufio.NewScanner(resp.Body)
maxLineSize := defaultMaxLineSize
@@ -262,9 +271,45 @@ func (s *OpenAIGatewayService) streamRawChatCompletions(
var usage OpenAIUsage
var firstTokenMs *int
clientDisconnected := false
+ clientOutputStarted := false
+ pendingLines := make([]string, 0, 8)
+ refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen)
+
+ writeLine := func(line string) {
+ if clientDisconnected {
+ return
+ }
+ if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() {
+ pendingLines = append(pendingLines, line)
+ return
+ }
+ if !clientOutputStarted {
+ writeStreamHeaders()
+ for _, pending := range pendingLines {
+ if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil {
+ clientDisconnected = true
+ logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
+ zap.Error(werr),
+ zap.String("request_id", requestID),
+ )
+ return
+ }
+ }
+ pendingLines = pendingLines[:0]
+ clientOutputStarted = true
+ }
+ if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
+ clientDisconnected = true
+ logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
+ zap.Error(werr),
+ zap.String("request_id", requestID),
+ )
+ }
+ }
for scanner.Scan() {
line := scanner.Text()
+ refusalDetector.ObserveSSELine(line)
if payload, ok := extractOpenAISSEDataLine(line); ok {
trimmedPayload := strings.TrimSpace(payload)
if trimmedPayload != "[DONE]" {
@@ -279,22 +324,14 @@ func (s *OpenAIGatewayService) streamRawChatCompletions(
}
}
- if !clientDisconnected {
- if _, werr := c.Writer.WriteString(line + "\n"); werr != nil {
- clientDisconnected = true
- logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing",
- zap.Error(werr),
- zap.String("request_id", requestID),
- )
- }
- }
+ writeLine(line)
if line == "" {
- if !clientDisconnected {
+ if !clientDisconnected && clientOutputStarted {
c.Writer.Flush()
}
continue
}
- if !clientDisconnected {
+ if !clientDisconnected && clientOutputStarted {
c.Writer.Flush()
}
}
@@ -306,6 +343,27 @@ func (s *OpenAIGatewayService) streamRawChatCompletions(
zap.String("request_id", requestID),
)
}
+ } else if !clientDisconnected && !clientOutputStarted {
+ if refusalDetector.IsSilentRefusal() {
+ return nil, newOpenAISilentRefusalFailoverError(c, account, requestID)
+ }
+ if len(pendingLines) > 0 {
+ writeStreamHeaders()
+ for _, pending := range pendingLines {
+ if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil {
+ clientDisconnected = true
+ logger.L().Debug("openai chat_completions raw: client disconnected during final flush",
+ zap.Error(werr),
+ zap.String("request_id", requestID),
+ )
+ break
+ }
+ }
+ if !clientDisconnected {
+ c.Writer.Flush()
+ clientOutputStarted = true
+ }
+ }
}
return &OpenAIForwardResult{
@@ -422,16 +480,10 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions(
//
// - base 已是 /chat/completions:原样返回
// - base 以 /v1 结尾:追加 /chat/completions
+// - base 以其他版本段结尾(如 /v4):追加 /chat/completions
// - 其他情况:追加 /v1/chat/completions
//
// 与 buildOpenAIResponsesURL 是姐妹函数。
func buildOpenAIChatCompletionsURL(base string) string {
- normalized := strings.TrimRight(strings.TrimSpace(base), "/")
- if strings.HasSuffix(normalized, "/chat/completions") {
- return normalized
- }
- if strings.HasSuffix(normalized, "/v1") {
- return normalized + "/chat/completions"
- }
- return normalized + "/v1/chat/completions"
+ return buildOpenAIEndpointURL(base, "/v1/chat/completions")
}
diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go
index 1be07fd7..64449636 100644
--- a/backend/internal/service/openai_gateway_chat_completions_raw_test.go
+++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go
@@ -5,6 +5,7 @@ package service
import (
"bytes"
"context"
+ "errors"
"io"
"net/http"
"net/http/httptest"
@@ -36,6 +37,7 @@ func TestBuildOpenAIChatCompletionsURL(t *testing.T) {
// 第三方上游常见形式
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"},
{"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"},
+ {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/chat/completions"},
// 带空白字符
{"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"},
}
@@ -64,6 +66,7 @@ func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) {
{"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"},
{"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"},
{"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"},
+ {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/responses"},
{"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"},
}
@@ -118,6 +121,259 @@ func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDown
require.Contains(t, rec.Body.String(), "data: [DONE]")
}
+func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentNonStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := []byte(`{"model":"deepseek-reasoner","messages":[{"role":"user","content":"hello"}],"stream":false}`)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamJSON := `{"id":"chatcmpl_reasoning","object":"chat.completion","model":"deepseek-reasoner","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"think first","content":"final answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":5,"total_tokens":8}}`
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_deepseek_reasoning_json"}},
+ Body: io.NopCloser(strings.NewReader(upstreamJSON)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+ account := rawChatCompletionsTestAccount()
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 3, result.Usage.InputTokens)
+ require.Equal(t, 5, result.Usage.OutputTokens)
+ require.Equal(t, "think first", gjson.Get(rec.Body.String(), "choices.0.message.reasoning_content").String())
+ require.Equal(t, "final answer", gjson.Get(rec.Body.String(), "choices.0.message.content").String())
+}
+
+func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := []byte(`{"model":"deepseek-reasoner","messages":[{"role":"user","content":"hello"}],"stream":true}`)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`,
+ "",
+ `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"reasoning_content":"think first"},"finish_reason":null}]}`,
+ "",
+ `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"content":"final answer"},"finish_reason":null}]}`,
+ "",
+ `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5,"total_tokens":8}}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_deepseek_reasoning_stream"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+ account := rawChatCompletionsTestAccount()
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 3, result.Usage.InputTokens)
+ require.Equal(t, 5, result.Usage.OutputTokens)
+ require.Contains(t, rec.Body.String(), `"reasoning_content":"think first"`)
+ require.Contains(t, rec.Body.String(), `"content":"final answer"`)
+ require.Contains(t, rec.Body.String(), "data: [DONE]")
+}
+
+func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentInRequest(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := []byte(`{"model":"deepseek-v4-pro","messages":[{"role":"user","content":"weather"},{"role":"assistant","reasoning_content":"need tool","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{}"}}]},{"role":"tool","tool_call_id":"call_1","content":"cloudy"}],"stream":false}`)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_deepseek_reasoning_request"}},
+ Body: io.NopCloser(strings.NewReader(`{"id":"chatcmpl_request","object":"chat.completion","model":"deepseek-v4-pro","choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":2,"total_tokens":6}}`)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+ account := rawChatCompletionsTestAccount()
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, "need tool", gjson.GetBytes(upstream.lastBody, "messages.1.reasoning_content").String())
+ require.Equal(t, "get_weather", gjson.GetBytes(upstream.lastBody, "messages.1.tool_calls.0.function.name").String())
+}
+
+func TestForwardAsRawChatCompletions_SilentRefusalTriggersFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := largeRawChatCompletionsBody()
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
+ "",
+ `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_silent"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "")
+ require.Nil(t, result)
+ var failoverErr *UpstreamFailoverError
+ require.True(t, errors.As(err, &failoverErr))
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.True(t, IsOpenAISilentRefusalErrorBody(failoverErr.ResponseBody))
+ require.False(t, c.Writer.Written(), "silent refusal must not commit a 200 response before failover")
+ require.Empty(t, rec.Body.String())
+}
+
+func TestForwardAsRawChatCompletions_SilentRefusalToolCallsExempt(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := largeRawChatCompletionsBody()
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
+ "",
+ `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"lookup","arguments":""}}]}}]}`,
+ "",
+ `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_tool"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), `"tool_calls"`)
+ require.Contains(t, rec.Body.String(), `"finish_reason":"tool_calls"`)
+}
+
+func TestHandleChatStreamingResponse_SilentRefusalReasoningSummaryExempt(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
+
+ upstreamBody := strings.Join([]string{
+ `data: {"type":"response.created","response":{"id":"resp_reasoning","model":"gpt-5.5"}}`,
+ "",
+ `data: {"type":"response.reasoning_summary_text.delta","delta":"thinking only"}`,
+ "",
+ `data: {"type":"response.completed","response":{"id":"resp_reasoning","model":"gpt-5.5","status":"completed"}}`,
+ "",
+ }, "\n")
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_reasoning"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }
+ svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()}
+
+ result, err := svc.handleChatStreamingResponse(
+ resp,
+ c,
+ rawChatCompletionsTestAccount(),
+ "gpt-5.5",
+ "gpt-5.5",
+ "gpt-5.5",
+ false,
+ time.Now(),
+ openAISilentRefusalMinRequestBodyBytes,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), `"reasoning_content":"thinking only"`)
+ require.Contains(t, rec.Body.String(), "data: [DONE]")
+}
+
+func TestForwardAsRawChatCompletions_SilentRefusalNormalContentExempt(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := largeRawChatCompletionsBody()
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
+ "",
+ `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":"ok"}}]}`,
+ "",
+ `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`,
+ "",
+ "data: [DONE]",
+ "",
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ok"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+
+ result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), `"content":"ok"`)
+ require.Contains(t, rec.Body.String(), "data: [DONE]")
+}
+
func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -193,6 +449,49 @@ func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testi
require.NoError(t, upstream.lastReq.Context().Err())
}
+func TestForwardAsChatCompletions_UnknownResponsesSupportFallbackUsesVersionedChatURL(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ body := []byte(`{"model":"glm-4.5-air","messages":[{"role":"user","content":"hello"}],"stream":false}`)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstream := &httpUpstreamRecorder{responses: []*http.Response{
+ {
+ StatusCode: http.StatusNotFound,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"error":{"message":"not found"}}`)),
+ },
+ {
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_raw_fallback"}},
+ Body: io.NopCloser(strings.NewReader(
+ `{"id":"chatcmpl_1","object":"chat.completion","model":"glm-4.5-air","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`,
+ )),
+ },
+ }}
+
+ svc := &OpenAIGatewayService{
+ cfg: rawChatCompletionsTestConfig(),
+ httpUpstream: upstream,
+ }
+ account := rawChatCompletionsTestAccount()
+ account.Credentials["base_url"] = "https://open.bigmodel.cn/api/paas/v4"
+
+ result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Equal(t, 1, result.Usage.InputTokens)
+ require.Equal(t, 2, result.Usage.OutputTokens)
+ require.Len(t, upstream.requests, 2)
+ require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/responses", upstream.requests[0].URL.String())
+ require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/chat/completions", upstream.requests[1].URL.String())
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.Contains(t, rec.Body.String(), `"content":"ok"`)
+}
+
func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) {
t.Parallel()
@@ -258,3 +557,9 @@ func rawChatCompletionsTestAccount() *Account {
},
}
}
+
+func largeRawChatCompletionsBody() []byte {
+ return []byte(`{"model":"gpt-5.5","messages":[{"role":"user","content":"` +
+ strings.Repeat("x", openAISilentRefusalMinRequestBodyBytes) +
+ `"}],"stream":true}`)
+}
diff --git a/backend/internal/service/openai_gateway_chat_completions_test.go b/backend/internal/service/openai_gateway_chat_completions_test.go
index b0d1fa31..a26091a3 100644
--- a/backend/internal/service/openai_gateway_chat_completions_test.go
+++ b/backend/internal/service/openai_gateway_chat_completions_test.go
@@ -236,6 +236,120 @@ func TestForwardAsChatCompletions_TerminalUsageWithoutUpstreamCloseReturns(t *te
}
}
+func TestForwardAsChatCompletions_EventNamedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := []byte(strings.Join([]string{
+ `event: response.created`,
+ `data: {"response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
+ ``,
+ `event: response.output_text.delta`,
+ `data: {"delta":"ok"}`,
+ ``,
+ `event: response.completed`,
+ `data: {"response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}`,
+ ``,
+ ``,
+ }, "\n"))
+ upstreamStream := newOpenAICompatBlockingReadCloser(upstreamBody)
+ defer func() {
+ require.NoError(t, upstreamStream.Close())
+ }()
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_event_named_terminal"}},
+ Body: upstreamStream,
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+
+ type forwardResult struct {
+ result *OpenAIForwardResult
+ err error
+ }
+ resultCh := make(chan forwardResult, 1)
+ go func() {
+ result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
+ resultCh <- forwardResult{result: result, err: err}
+ }()
+
+ select {
+ case got := <-resultCh:
+ require.NoError(t, got.err)
+ require.NotNil(t, got.result)
+ require.Equal(t, 17, got.result.Usage.InputTokens)
+ require.Equal(t, 8, got.result.Usage.OutputTokens)
+ require.Equal(t, 6, got.result.Usage.CacheReadInputTokens)
+ require.Contains(t, rec.Body.String(), `"content":"ok"`)
+ case <-time.After(time.Second):
+ require.Fail(t, "ForwardAsChatCompletions should use SSE event names when data payloads omit type")
+ }
+}
+
+func TestForwardAsChatCompletions_EventTypeDoesNotLeakAcrossFrames(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"stream":true}`)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstreamBody := strings.Join([]string{
+ `event: response.created`,
+ `data: {"response":{"id":"resp_1","model":"gpt-5.4","status":"in_progress","output":[]}}`,
+ ``,
+ `data: {"type":"response.output_text.delta","delta":"ok"}`,
+ ``,
+ `event: response.completed`,
+ `data: {"response":{"id":"resp_1","object":"response","model":"gpt-5.4","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":17,"output_tokens":8,"total_tokens":25,"input_tokens_details":{"cached_tokens":6}}}}`,
+ ``,
+ `data: [DONE]`,
+ ``,
+ }, "\n")
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_chat_event_boundary"}},
+ Body: io.NopCloser(strings.NewReader(upstreamBody)),
+ }}
+
+ svc := &OpenAIGatewayService{httpUpstream: upstream}
+ account := &Account{
+ ID: 1,
+ Name: "openai-oauth",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ "chatgpt_account_id": "chatgpt-acc",
+ },
+ }
+
+ result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "gpt-5.1")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.Contains(t, rec.Body.String(), `"content":"ok"`)
+ require.Contains(t, rec.Body.String(), `data: [DONE]`)
+}
+
func TestForwardAsChatCompletions_BufferedTerminalWithoutUpstreamCloseReturns(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/service/openai_gateway_messages.go b/backend/internal/service/openai_gateway_messages.go
index aefa8fd2..6d74f7dd 100644
--- a/backend/internal/service/openai_gateway_messages.go
+++ b/backend/internal/service/openai_gateway_messages.go
@@ -560,10 +560,24 @@ func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
}()
defer close(done)
+ var parser openAICompatSSEFrameParser
for {
select {
case ev, ok := <-events:
if !ok {
+ if frame, ok := parser.Finish(); ok {
+ payload := openAICompatPayloadWithEventType(frame.Data, frame.EventType)
+ var event apicompat.ResponsesStreamEvent
+ if err := json.Unmarshal([]byte(payload), &event); err == nil {
+ acc.ProcessEvent(&event)
+ if isOpenAICompatResponsesTerminalEvent(event.Type) && event.Response != nil {
+ if event.Response.Usage != nil {
+ usage = copyOpenAIUsageFromResponsesUsage(event.Response.Usage)
+ }
+ return event.Response, usage, acc, nil
+ }
+ }
+ }
return nil, usage, acc, nil
}
resetTimeout()
@@ -580,10 +594,11 @@ func (s *OpenAIGatewayService) readOpenAICompatBufferedTerminal(
if isOpenAICompatDoneSentinelLine(ev.line) {
return nil, usage, acc, nil
}
- payload, ok := extractOpenAISSEDataLine(ev.line)
- if !ok || payload == "" {
+ frame, ok := parser.AddLine(ev.line)
+ if !ok {
continue
}
+ payload := openAICompatPayloadWithEventType(frame.Data, frame.EventType)
var event apicompat.ResponsesStreamEvent
if err := json.Unmarshal([]byte(payload), &event); err != nil {
@@ -772,6 +787,10 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
missingTerminalErr := func() (*OpenAIForwardResult, error) {
return resultWithUsage(), fmt.Errorf("stream usage incomplete: missing terminal event")
}
+ processFrame := func(frame openAICompatSSEFrame) bool {
+ payload := openAICompatPayloadWithEventType(frame.Data, frame.EventType)
+ return processDataLine(payload)
+ }
// ── Determine keepalive interval ──
keepaliveInterval := time.Duration(0)
@@ -781,16 +800,17 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
// ── No keepalive: fast synchronous path (no goroutine overhead) ──
if streamInterval <= 0 && keepaliveInterval <= 0 {
+ var parser openAICompatSSEFrameParser
for scanner.Scan() {
line := scanner.Text()
if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
- payload, ok := extractOpenAISSEDataLine(line)
+ frame, ok := parser.AddLine(line)
if !ok {
continue
}
- if processDataLine(payload) {
+ if processFrame(frame) {
return finalizeStream()
}
}
@@ -798,6 +818,14 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
handleScanErr(err)
return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
}
+ if frame, ok := parser.Finish(); ok {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
+ return missingTerminalErr()
+ }
+ if processFrame(frame) {
+ return finalizeStream()
+ }
+ }
return missingTerminalErr()
}
@@ -842,12 +870,21 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
keepaliveCh = keepaliveTicker.C
}
lastDataAt := time.Now()
+ var parser openAICompatSSEFrameParser
for {
select {
case ev, ok := <-events:
if !ok {
// Upstream closed
+ if frame, ok := parser.Finish(); ok {
+ if strings.TrimSpace(frame.Data) == "[DONE]" {
+ return missingTerminalErr()
+ }
+ if processFrame(frame) {
+ return finalizeStream()
+ }
+ }
return missingTerminalErr()
}
if ev.err != nil {
@@ -859,11 +896,11 @@ func (s *OpenAIGatewayService) handleAnthropicStreamingResponse(
if isOpenAICompatDoneSentinelLine(line) {
return missingTerminalErr()
}
- payload, ok := extractOpenAISSEDataLine(line)
+ frame, ok := parser.AddLine(line)
if !ok {
continue
}
- if processDataLine(payload) {
+ if processFrame(frame) {
return finalizeStream()
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
index cf909ec9..096f5b10 100644
--- a/backend/internal/service/openai_gateway_record_usage_test.go
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -1320,6 +1320,93 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing.
require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
}
+func TestOpenAIGatewayServiceRecordUsage_EmptyImageSizeDefaultsBeforeBillingAndPersistence(t *testing.T) {
+ imagePrice2K := 0.31
+ groupID := int64(1201)
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_image_default_size",
+ Model: "gpt-image-2",
+ ImageCount: 2,
+ ImageSize: "",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 11201,
+ GroupID: i64p(groupID),
+ Group: &Group{
+ ID: groupID,
+ RateMultiplier: 1.0,
+ ImagePrice2K: &imagePrice2K,
+ },
+ },
+ User: &User{ID: 21201},
+ Account: &Account{ID: 31201},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 2, usageRepo.lastLog.ImageCount)
+ require.NotNil(t, usageRepo.lastLog.ImageSize)
+ require.Equal(t, ImageBillingSize2K, *usageRepo.lastLog.ImageSize)
+ require.NotNil(t, usageRepo.lastLog.ImageSizeSource)
+ require.Equal(t, ImageSizeSourceDefault, *usageRepo.lastLog.ImageSizeSource)
+ require.Nil(t, usageRepo.lastLog.ImageInputSize)
+ require.Nil(t, usageRepo.lastLog.ImageOutputSize)
+ require.InDelta(t, 0.62, usageRepo.lastLog.TotalCost, 1e-12)
+ require.InDelta(t, 0.62, usageRepo.lastLog.ActualCost, 1e-12)
+ require.NotNil(t, usageRepo.lastLog.BillingMode)
+ require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_OutputImageSizeWinsBeforeBillingAndPersistence(t *testing.T) {
+ imagePrice1K := 0.11
+ imagePrice4K := 0.44
+ groupID := int64(1202)
+ usageRepo := &openAIRecordUsageLogRepoStub{inserted: true}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_image_output_size",
+ Model: "gpt-image-2",
+ ImageCount: 1,
+ ImageInputSize: "1024x1024",
+ ImageOutputSizes: []string{"3840x2160"},
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 11202,
+ GroupID: i64p(groupID),
+ Group: &Group{
+ ID: groupID,
+ RateMultiplier: 1.0,
+ ImagePrice1K: &imagePrice1K,
+ ImagePrice4K: &imagePrice4K,
+ },
+ },
+ User: &User{ID: 21202},
+ Account: &Account{ID: 31202},
+ })
+
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.NotNil(t, usageRepo.lastLog.ImageSize)
+ require.Equal(t, ImageBillingSize4K, *usageRepo.lastLog.ImageSize)
+ require.NotNil(t, usageRepo.lastLog.ImageInputSize)
+ require.Equal(t, "1024x1024", *usageRepo.lastLog.ImageInputSize)
+ require.NotNil(t, usageRepo.lastLog.ImageOutputSize)
+ require.Equal(t, "3840x2160", *usageRepo.lastLog.ImageOutputSize)
+ require.NotNil(t, usageRepo.lastLog.ImageSizeSource)
+ require.Equal(t, ImageSizeSourceOutput, *usageRepo.lastLog.ImageSizeSource)
+ require.Equal(t, map[string]int{ImageBillingSize4K: 1}, usageRepo.lastLog.ImageSizeBreakdown)
+ require.InDelta(t, 0.44, usageRepo.lastLog.TotalCost, 1e-12)
+ require.InDelta(t, 0.44, usageRepo.lastLog.ActualCost, 1e-12)
+}
+
func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) {
imagePrice := 0.02
groupID := int64(12)
@@ -1641,3 +1728,42 @@ func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier(
require.InDelta(t, 0.80, cost.TotalCost, 1e-12)
require.InDelta(t, 0.80, cost.ActualCost, 1e-12)
}
+
+func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingNormalizesMissingSizeTier(t *testing.T) {
+ groupID := int64(128)
+ defaultPrice := 0.10
+ price2K := 0.22
+ cache := newEmptyChannelCache()
+ cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{
+ BillingMode: BillingModeImage,
+ PerRequestPrice: &defaultPrice,
+ Intervals: []PricingInterval{{
+ TierLabel: "2K",
+ PerRequestPrice: &price2K,
+ }},
+ }
+ cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive}
+ cache.loadedAt = time.Now()
+ channelService := &ChannelService{}
+ channelService.cache.Store(cache)
+
+ svc := &GatewayService{
+ billingService: NewBillingService(&config.Config{}, nil),
+ resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)),
+ }
+
+ cost := svc.calculateRecordUsageCost(
+ context.Background(),
+ &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: ""},
+ &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}},
+ "gemini-image",
+ 1.0,
+ 1.0,
+ nil,
+ )
+
+ require.NotNil(t, cost)
+ require.Equal(t, string(BillingModeImage), cost.BillingMode)
+ require.InDelta(t, 0.44, cost.TotalCost, 1e-12)
+ require.InDelta(t, 0.44, cost.ActualCost, 1e-12)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index e12b208e..cfaf5bff 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -228,14 +228,19 @@ type OpenAIForwardResult struct {
ServiceTier *string
// ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix.
// Stored for usage records display; nil means not provided / not applicable.
- ReasoningEffort *string
- Stream bool
- OpenAIWSMode bool
- ResponseHeaders http.Header
- Duration time.Duration
- FirstTokenMs *int
- ImageCount int
- ImageSize string
+ ReasoningEffort *string
+ Stream bool
+ OpenAIWSMode bool
+ ResponseHeaders http.Header
+ Duration time.Duration
+ FirstTokenMs *int
+ ImageCount int
+ ImageSize string
+ ImageInputSize string
+ ImageOutputSize string
+ ImageOutputSizes []string
+ ImageSizeSource string
+ ImageSizeBreakdown map[string]int
}
type OpenAIWSRetryMetricsSnapshot struct {
@@ -1113,6 +1118,9 @@ func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string
if strings.Contains(lower, "an error occurred while processing your request") {
return true
}
+ if strings.Contains(lower, "selected model is at capacity") {
+ return true
+ }
return strings.Contains(lower, "you can retry your request") &&
strings.Contains(lower, "help.openai.com") &&
strings.Contains(lower, "request id")
@@ -2416,9 +2424,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
imageBillingModel := ""
imageSizeTier := ""
+ imageInputSize := ""
if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) {
var imageCfgErr error
- imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel)
+ imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailed(reqBody, billingModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
@@ -2430,6 +2439,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
return nil, imageCfgErr
}
+ imageBillingModel = imageCfg.Model
+ imageSizeTier = imageCfg.SizeTier
+ imageInputSize = imageCfg.InputSize
}
// Re-serialize body only if modified
@@ -2461,9 +2473,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
return nil, err
}
- // Capture upstream request body for ops retry of this attempt.
- setOpsUpstreamRequestBody(c, body)
-
// 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 {
wsReqBody := reqBody
@@ -2671,6 +2680,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
wsResult.UpstreamModel = upstreamModel
if wsResult.ImageCount > 0 {
wsResult.ImageSize = imageSizeTier
+ wsResult.ImageInputSize = imageInputSize
wsResult.BillingModel = imageBillingModel
}
return wsResult, nil
@@ -2735,7 +2745,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if err != nil {
return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err)
}
- setOpsUpstreamRequestBody(c, body)
httpInvalidEncryptedContentRetryTried = true
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name)
continue
@@ -2773,10 +2782,15 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
defer func() { _ = resp.Body.Close() }()
+ reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
+ serviceTier := extractOpenAIServiceTier(reqBody)
+ releaseOpenAIParsedRequestBody(c)
+
// Handle normal response
var usage *OpenAIUsage
var firstTokenMs *int
imageCount := 0
+ var imageOutputSizes []string
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel)
if err != nil {
@@ -2785,6 +2799,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
imageCount = streamResult.imageCount
+ imageOutputSizes = streamResult.imageOutputSizes
} else {
nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel)
if err != nil {
@@ -2792,6 +2807,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
usage = nonStreamResult.usage
imageCount = nonStreamResult.imageCount
+ imageOutputSizes = nonStreamResult.imageOutputSizes
}
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
@@ -2805,9 +2821,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
usage = &OpenAIUsage{}
}
- reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel)
- serviceTier := extractOpenAIServiceTier(reqBody)
-
forwardResult := &OpenAIForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
@@ -2823,6 +2836,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
+ forwardResult.ImageInputSize = imageInputSize
+ forwardResult.ImageOutputSizes = imageOutputSizes
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
@@ -2927,9 +2942,10 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
imageBillingModel := ""
imageSizeTier := ""
+ imageInputSize := ""
if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) {
var imageCfgErr error
- imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel)
+ imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body, reqModel)
if imageCfgErr != nil {
setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "")
c.JSON(http.StatusBadRequest, gin.H{
@@ -2941,6 +2957,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
})
return nil, imageCfgErr
}
+ imageBillingModel = imageCfg.Model
+ imageSizeTier = imageCfg.SizeTier
+ imageInputSize = imageCfg.InputSize
}
logger.LegacyPrintf("service.openai_gateway",
@@ -2984,7 +3003,6 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
proxyURL = account.Proxy.URL()
}
- setOpsUpstreamRequestBody(c, body)
if c != nil {
c.Set("openai_passthrough", true)
}
@@ -3023,9 +3041,12 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body)
}
+ serviceTier := extractOpenAIServiceTierFromBody(body)
+
var usage *OpenAIUsage
var firstTokenMs *int
imageCount := 0
+ var imageOutputSizes []string
if reqStream {
result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel)
if err != nil {
@@ -3034,6 +3055,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
usage = result.usage
firstTokenMs = result.firstTokenMs
imageCount = result.imageCount
+ imageOutputSizes = result.imageOutputSizes
} else {
result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel)
if err != nil {
@@ -3041,6 +3063,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
usage = result.usage
imageCount = result.imageCount
+ imageOutputSizes = result.imageOutputSizes
}
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
@@ -3056,7 +3079,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
Usage: *usage,
Model: reqModel,
UpstreamModel: upstreamPassthroughModel,
- ServiceTier: extractOpenAIServiceTierFromBody(body),
+ ServiceTier: serviceTier,
ReasoningEffort: reasoningEffort,
Stream: reqStream,
OpenAIWSMode: false,
@@ -3066,6 +3089,8 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
if imageCount > 0 {
forwardResult.ImageCount = imageCount
forwardResult.ImageSize = imageSizeTier
+ forwardResult.ImageInputSize = imageInputSize
+ forwardResult.ImageOutputSizes = imageOutputSizes
forwardResult.BillingModel = imageBillingModel
}
return forwardResult, nil
@@ -3361,15 +3386,17 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
}
type openaiStreamingResultPassthrough struct {
- usage *OpenAIUsage
- firstTokenMs *int
- imageCount int
+ usage *OpenAIUsage
+ firstTokenMs *int
+ imageCount int
+ imageOutputSizes []string
}
type openaiNonStreamingResultPassthrough struct {
*OpenAIUsage
- usage *OpenAIUsage
- imageCount int
+ usage *OpenAIUsage
+ imageCount int
+ imageOutputSizes []string
}
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
@@ -3400,6 +3427,9 @@ func openAIStreamDataStartsClientOutput(data, eventType string) bool {
}
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
+ if isOpenAITransientProcessingError(http.StatusBadRequest, message, payload) {
+ return true
+ }
code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
if code == "" {
code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
@@ -3539,7 +3569,12 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
resultWithUsage := func() *openaiStreamingResultPassthrough {
- return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
+ return &openaiStreamingResultPassthrough{
+ usage: usage,
+ firstTokenMs: firstTokenMs,
+ imageCount: imageCounter.Count(),
+ imageOutputSizes: imageCounter.Sizes(),
+ }
}
for scanner.Scan() {
@@ -3696,9 +3731,10 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
}
c.Data(resp.StatusCode, contentType, body)
return &openaiNonStreamingResultPassthrough{
- OpenAIUsage: usage,
- usage: usage,
- imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ imageOutputSizes: collectOpenAIResponseImageOutputSizesFromJSONBytes(body),
}, nil
}
@@ -3758,9 +3794,10 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c
c.Data(resp.StatusCode, contentType, body)
return &openaiNonStreamingResultPassthrough{
- OpenAIUsage: usage,
- usage: usage,
- imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ imageOutputSizes: collectOpenAIImageOutputSizesFromSSEBody(bodyText),
}, nil
}
@@ -4182,15 +4219,17 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse(
// openaiStreamingResult streaming response result
type openaiStreamingResult struct {
- usage *OpenAIUsage
- firstTokenMs *int
- imageCount int
+ usage *OpenAIUsage
+ firstTokenMs *int
+ imageCount int
+ imageOutputSizes []string
}
type openaiNonStreamingResult struct {
*OpenAIUsage
- usage *OpenAIUsage
- imageCount int
+ usage *OpenAIUsage
+ imageCount int
+ imageOutputSizes []string
}
func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) {
@@ -4303,7 +4342,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
needModelReplace := originalModel != mappedModel
resultWithUsage := func() *openaiStreamingResult {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
+ return &openaiStreamingResult{
+ usage: usage,
+ firstTokenMs: firstTokenMs,
+ imageCount: imageCounter.Count(),
+ imageOutputSizes: imageCounter.Sizes(),
+ }
}
finalizeStream := func() (*openaiStreamingResult, error) {
if !sawTerminalEvent {
@@ -4578,6 +4622,76 @@ func extractOpenAISSEDataLine(line string) (string, bool) {
return line[start:], true
}
+func extractOpenAISSEEventLine(line string) (string, bool) {
+ if !strings.HasPrefix(line, "event:") {
+ return "", false
+ }
+ start := len("event:")
+ for start < len(line) {
+ if line[start] != ' ' && line[start] != ' ' {
+ break
+ }
+ start++
+ }
+ return strings.TrimSpace(line[start:]), true
+}
+
+type openAICompatSSEFrame struct {
+ EventType string
+ Data string
+}
+
+type openAICompatSSEFrameParser struct {
+ eventType string
+ dataLines []string
+}
+
+func (p *openAICompatSSEFrameParser) AddLine(line string) (openAICompatSSEFrame, bool) {
+ if line == "" {
+ return p.dispatch()
+ }
+ if strings.HasPrefix(line, ":") {
+ return openAICompatSSEFrame{}, false
+ }
+ if eventType, ok := extractOpenAISSEEventLine(line); ok {
+ p.eventType = eventType
+ return openAICompatSSEFrame{}, false
+ }
+ if data, ok := extractOpenAISSEDataLine(line); ok {
+ p.dataLines = append(p.dataLines, data)
+ }
+ return openAICompatSSEFrame{}, false
+}
+
+func (p *openAICompatSSEFrameParser) Finish() (openAICompatSSEFrame, bool) {
+ return p.dispatch()
+}
+
+func (p *openAICompatSSEFrameParser) dispatch() (openAICompatSSEFrame, bool) {
+ frame := openAICompatSSEFrame{
+ EventType: p.eventType,
+ Data: strings.Join(p.dataLines, "\n"),
+ }
+ p.eventType = ""
+ p.dataLines = nil
+ return frame, frame.Data != ""
+}
+
+func openAICompatPayloadWithEventType(payload, eventType string) string {
+ eventType = strings.TrimSpace(eventType)
+ if eventType == "" || strings.TrimSpace(payload) == "" || strings.TrimSpace(payload) == "[DONE]" {
+ return payload
+ }
+ if gjson.Get(payload, "type").Exists() {
+ return payload
+ }
+ patched, err := sjson.Set(payload, "type", eventType)
+ if err != nil {
+ return payload
+ }
+ return patched
+}
+
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
@@ -4639,28 +4753,47 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag
return
}
- usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
- usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
- usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
- usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int())
+ if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(data); ok {
+ *usage = parsedUsage
+ }
}
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
if len(body) == 0 || !gjson.ValidBytes(body) {
return OpenAIUsage{}, false
}
- values := gjson.GetManyBytes(
- body,
- "usage.input_tokens",
- "usage.output_tokens",
- "usage.input_tokens_details.cached_tokens",
- "usage.output_tokens_details.image_tokens",
- )
+ if usage, ok := openAIUsageFromGJSON(gjson.GetBytes(body, "usage")); ok {
+ return usage, true
+ }
+ return openAIUsageFromGJSON(gjson.GetBytes(body, "response.usage"))
+}
+
+func openAIUsageFromGJSON(value gjson.Result) (OpenAIUsage, bool) {
+ if !value.Exists() || !value.IsObject() {
+ return OpenAIUsage{}, false
+ }
+ inputTokens := value.Get("input_tokens").Int()
+ if inputTokens == 0 {
+ inputTokens = value.Get("prompt_tokens").Int()
+ }
+ outputTokens := value.Get("output_tokens").Int()
+ if outputTokens == 0 {
+ outputTokens = value.Get("completion_tokens").Int()
+ }
+ cacheReadTokens := value.Get("input_tokens_details.cached_tokens").Int()
+ if cacheReadTokens == 0 {
+ cacheReadTokens = value.Get("prompt_tokens_details.cached_tokens").Int()
+ }
+ imageOutputTokens := value.Get("output_tokens_details.image_tokens").Int()
+ if imageOutputTokens == 0 {
+ imageOutputTokens = value.Get("completion_tokens_details.image_tokens").Int()
+ }
return OpenAIUsage{
- InputTokens: int(values[0].Int()),
- OutputTokens: int(values[1].Int()),
- CacheReadInputTokens: int(values[2].Int()),
- ImageOutputTokens: int(values[3].Int()),
+ InputTokens: int(inputTokens),
+ OutputTokens: int(outputTokens),
+ CacheCreationInputTokens: int(value.Get("cache_creation_input_tokens").Int()),
+ CacheReadInputTokens: int(cacheReadTokens),
+ ImageOutputTokens: int(imageOutputTokens),
}, true
}
@@ -4711,9 +4844,10 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
c.Data(resp.StatusCode, contentType, body)
return &openaiNonStreamingResult{
- OpenAIUsage: usage,
- usage: usage,
- imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body),
+ imageOutputSizes: collectOpenAIResponseImageOutputSizesFromJSONBytes(body),
}, nil
}
@@ -4775,9 +4909,10 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte
c.Data(resp.StatusCode, contentType, body)
return &openaiNonStreamingResult{
- OpenAIUsage: usage,
- usage: usage,
- imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ OpenAIUsage: usage,
+ usage: usage,
+ imageCount: countOpenAIImageOutputsFromSSEBody(bodyText),
+ imageOutputSizes: collectOpenAIImageOutputSizesFromSSEBody(bodyText),
}, nil
}
@@ -4955,17 +5090,11 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro
// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。
// - base 以 /v1 结尾:追加 /responses
+// - base 以其他版本段结尾(如 /v4):追加 /responses
// - base 已是 /responses:原样返回
// - 其他情况:追加 /v1/responses
func buildOpenAIResponsesURL(base string) string {
- normalized := strings.TrimRight(strings.TrimSpace(base), "/")
- if strings.HasSuffix(normalized, "/responses") {
- return normalized
- }
- if strings.HasSuffix(normalized, "/v1") {
- return normalized + "/responses"
- }
- return normalized + "/v1/responses"
+ return buildOpenAIEndpointURL(base, "/v1/responses")
}
func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool {
@@ -5216,6 +5345,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
user := input.User
account := input.Account
subscription := input.Subscription
+ ApplyOpenAIImageBillingResolution(result)
// 计算实际的新输入token(减去缓存读取的token)
// 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费
@@ -5325,6 +5455,10 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
ImageOutputTokens: result.Usage.ImageOutputTokens,
ImageCount: result.ImageCount,
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
+ ImageInputSize: optionalTrimmedStringPtr(result.ImageInputSize),
+ ImageOutputSize: optionalTrimmedStringPtr(result.ImageOutputSize),
+ ImageSizeSource: optionalTrimmedStringPtr(result.ImageSizeSource),
+ ImageSizeBreakdown: result.ImageSizeBreakdown,
}
if cost != nil {
usageLog.InputCost = cost.InputCost
@@ -5493,6 +5627,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
result *OpenAIForwardResult,
multiplier float64,
) *CostBreakdown {
+ sizeTier := NormalizeImageBillingTierOrDefault(result.ImageSize)
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil &&
(resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) {
gid := apiKey.Group.ID
@@ -5501,7 +5636,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
Model: billingModel,
GroupID: &gid,
RequestCount: result.ImageCount,
- SizeTier: result.ImageSize,
+ SizeTier: sizeTier,
RateMultiplier: multiplier,
Resolver: s.resolver,
Resolved: resolved,
@@ -5520,7 +5655,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost(
Price4K: apiKey.Group.ImagePrice4K,
}
}
- return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier)
+ return s.billingService.CalculateImageCost(billingModel, sizeTier, result.ImageCount, groupConfig, multiplier)
}
func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
@@ -6098,7 +6233,7 @@ func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlocked
// applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses
// WS payload:
//
-// - pass: returns frame unchanged (newBytes == frame, blocked == nil)
+// - pass: keeps service_tier, normalizing aliases such as "fast" to "priority"
// - filter: returns a copy with top-level service_tier removed
// - block: returns (frame, *OpenAIFastBlockedError)
//
@@ -6162,7 +6297,14 @@ func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate(
}
return trimmed, nil, nil
default:
- return frame, nil, nil
+ if normTier == rawTier {
+ return frame, nil, nil
+ }
+ updated, err := sjson.SetBytes(frame, "service_tier", normTier)
+ if err != nil {
+ return frame, nil, fmt.Errorf("normalize service_tier in ws frame: %w", err)
+ }
+ return updated, nil, nil
}
}
@@ -6359,6 +6501,13 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error
return reqBody, nil
}
+func releaseOpenAIParsedRequestBody(c *gin.Context) {
+ if c == nil {
+ return
+ }
+ delete(c.Keys, OpenAIParsedRequestBodyKey)
+}
+
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {
diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go
index fe58e92f..951860cd 100644
--- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go
+++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go
@@ -218,6 +218,12 @@ func TestIsOpenAITransientProcessingError(t *testing.T) {
nil,
))
+ require.True(t, isOpenAITransientProcessingError(
+ http.StatusBadRequest,
+ "Selected model is at capacity. Please try a different model.",
+ []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`),
+ ))
+
require.True(t, isOpenAITransientProcessingError(
http.StatusBadRequest,
"",
@@ -332,3 +338,55 @@ func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t
require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request")
require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应")
}
+
+func TestOpenAIGatewayService_Forward_ModelCapacityErrorTriggersFailoverAndSameAccountRetry(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
+ c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ upstream := &httpUpstreamRecorder{
+ resp: &http.Response{
+ StatusCode: http.StatusBadRequest,
+ Header: http.Header{
+ "Content-Type": []string{"application/json"},
+ "x-request-id": []string{"rid-capacity-400"},
+ },
+ Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`)),
+ },
+ }
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{
+ Gateway: config.GatewayConfig{ForceCodexCLI: false},
+ },
+ httpUpstream: upstream,
+ }
+ account := &Account{
+ ID: 1001,
+ Name: "codex max套餐",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ "pool_mode": true,
+ },
+ Status: StatusActive,
+ Schedulable: true,
+ RateMultiplier: f64p(1),
+ }
+ body := []byte(`{"model":"gpt-5.4","stream":false,"input":[{"type":"text","text":"hello"}]}`)
+
+ _, err := svc.Forward(context.Background(), c, account, body)
+ require.Error(t, err)
+
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode)
+ require.True(t, failoverErr.RetryableOnSameAccount)
+ require.Contains(t, string(failoverErr.ResponseBody), "Selected model is at capacity")
+ require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层重试/换号,而不是直接向客户端写响应")
+}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 24095f2b..0fac4508 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -1140,6 +1140,47 @@ func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T)
require.Empty(t, rec.Body.String())
}
+func TestOpenAIStreamingResponseFailedBeforeOutputCapacityErrorReturnsFailover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(strings.Join([]string{
+ "event: response.created",
+ `data: {"type":"response.created","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.in_progress",
+ `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
+ "",
+ "event: response.failed",
+ `data: {"type":"response.failed","error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`,
+ "",
+ }, "\n"))),
+ Header: http.Header{"X-Request-Id": []string{"rid-capacity-failed"}},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model")
+ require.Error(t, err)
+ var failoverErr *UpstreamFailoverError
+ require.ErrorAs(t, err, &failoverErr)
+ require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode)
+ require.Contains(t, string(failoverErr.ResponseBody), "Selected model is at capacity")
+ require.False(t, c.Writer.Written())
+ require.Empty(t, rec.Body.String())
+}
+
func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
@@ -2198,6 +2239,25 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
require.Equal(t, 13, usage.InputTokens)
require.Equal(t, 15, usage.OutputTokens)
require.Equal(t, 4, usage.CacheReadInputTokens)
+
+ svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage)
+ require.Equal(t, 21, usage.InputTokens)
+ require.Equal(t, 8, usage.OutputTokens)
+ require.Equal(t, 6, usage.CacheReadInputTokens)
+}
+
+func TestExtractOpenAIUsageFromJSONBytes_AcceptsResponseAndChatUsageShapes(t *testing.T) {
+ usage, ok := extractOpenAIUsageFromJSONBytes([]byte(`{"id":"resp_1","usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}`))
+ require.True(t, ok)
+ require.Equal(t, 3, usage.InputTokens)
+ require.Equal(t, 5, usage.OutputTokens)
+ require.Equal(t, 2, usage.CacheReadInputTokens)
+
+ usage, ok = extractOpenAIUsageFromJSONBytes([]byte(`{"type":"response.completed","response":{"usage":{"prompt_tokens":13,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}}`))
+ require.True(t, ok)
+ require.Equal(t, 13, usage.InputTokens)
+ require.Equal(t, 7, usage.OutputTokens)
+ require.Equal(t, 4, usage.CacheReadInputTokens)
}
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
@@ -2317,3 +2377,29 @@ func TestHandleSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
require.Contains(t, rec.Body.String(), "upstream rejected request")
require.Contains(t, rec.Header().Get("Content-Type"), "application/json")
}
+
+func TestOpenAICompatSSEFrameParserResetsEventTypeAtFrameBoundary(t *testing.T) {
+ var parser openAICompatSSEFrameParser
+
+ frame, ok := parser.AddLine("event: response.created")
+ require.False(t, ok)
+ require.Empty(t, frame)
+
+ frame, ok = parser.AddLine(`data: {"response":{"id":"resp_1"}}`)
+ require.False(t, ok)
+ require.Empty(t, frame)
+
+ frame, ok = parser.AddLine("")
+ require.True(t, ok)
+ require.Equal(t, "response.created", frame.EventType)
+ require.JSONEq(t, `{"response":{"id":"resp_1"}}`, frame.Data)
+
+ frame, ok = parser.AddLine(`data: {"delta":"ok"}`)
+ require.False(t, ok)
+ require.Empty(t, frame.EventType)
+
+ frame, ok = parser.AddLine("")
+ require.True(t, ok)
+ require.Empty(t, frame.EventType)
+ require.JSONEq(t, `{"delta":"ok"}`, frame.Data)
+}
diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go
index afa94156..95c054c9 100644
--- a/backend/internal/service/openai_images.go
+++ b/backend/internal/service/openai_images.go
@@ -532,54 +532,7 @@ func isOpenAINativeImageOption(name string) bool {
}
func normalizeOpenAIImageSizeTier(size string) string {
- trimmed := strings.TrimSpace(size)
- normalized := strings.ToLower(trimmed)
- switch normalized {
- case "", "auto":
- return "2K"
- case "1024x1024":
- return "1K"
- case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048":
- return "2K"
- case "3840x2160", "2160x3840":
- return "4K"
- }
- width, height, ok := parseOpenAIImageSizeDimensions(trimmed)
- if !ok {
- return "2K"
- }
- return classifyUnknownOpenAIImageSizeTier(width, height)
-}
-
-const (
- openAIImage2KMaxPixels = 2560 * 1440
-)
-
-func parseOpenAIImageSizeDimensions(size string) (int, int, bool) {
- trimmed := strings.TrimSpace(size)
- parts := strings.Split(strings.ToLower(trimmed), "x")
- if len(parts) != 2 {
- return 0, 0, false
- }
- width, err := strconv.Atoi(strings.TrimSpace(parts[0]))
- if err != nil {
- return 0, 0, false
- }
- height, err := strconv.Atoi(strings.TrimSpace(parts[1]))
- if err != nil {
- return 0, 0, false
- }
- if width <= 0 || height <= 0 {
- return 0, 0, false
- }
- return width, height, true
-}
-
-func classifyUnknownOpenAIImageSizeTier(width int, height int) string {
- if height > 0 && width > openAIImage2KMaxPixels/height {
- return "4K"
- }
- return "2K"
+ return NormalizeImageBillingTierOrDefault(size)
}
func (s *OpenAIGatewayService) ForwardImages(
@@ -635,10 +588,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
if err != nil {
return nil, err
}
- if !parsed.Multipart {
- setOpsUpstreamRequestBody(c, forwardBody)
- }
-
upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
defer releaseUpstreamCtx()
@@ -704,29 +653,46 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
imageCount := parsed.N
var firstTokenMs *int
if parsed.Stream && isEventStreamResponse(resp.Header) {
- streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
+ streamUsage, streamCount, streamSizes, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime)
if err != nil {
if streamCount > 0 {
return &OpenAIForwardResult{
- RequestID: resp.Header.Get("x-request-id"),
- Usage: streamUsage,
- Model: requestModel,
- UpstreamModel: upstreamModel,
- Stream: parsed.Stream,
- ResponseHeaders: resp.Header.Clone(),
- Duration: time.Since(startTime),
- FirstTokenMs: ttft,
- ImageCount: streamCount,
- ImageSize: parsed.SizeTier,
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: streamUsage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: ttft,
+ ImageCount: streamCount,
+ ImageSize: parsed.SizeTier,
+ ImageInputSize: parsed.Size,
+ ImageOutputSizes: streamSizes,
}, err
}
return nil, err
}
usage = streamUsage
imageCount = streamCount
+ imageOutputSizes := streamSizes
firstTokenMs = ttft
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ ImageInputSize: parsed.Size,
+ ImageOutputSizes: imageOutputSizes,
+ }, nil
} else {
- nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c)
+ nonStreamUsage, nonStreamCount, nonStreamSizes, err := s.handleOpenAIImagesNonStreamingResponse(resp, c)
if err != nil {
return nil, err
}
@@ -734,19 +700,21 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey(
if nonStreamCount > 0 {
imageCount = nonStreamCount
}
+ return &OpenAIForwardResult{
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: upstreamModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ ImageInputSize: parsed.Size,
+ ImageOutputSizes: nonStreamSizes,
+ }, nil
}
- return &OpenAIForwardResult{
- RequestID: resp.Header.Get("x-request-id"),
- Usage: usage,
- Model: requestModel,
- UpstreamModel: upstreamModel,
- Stream: parsed.Stream,
- ResponseHeaders: resp.Header.Clone(),
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- ImageCount: imageCount,
- ImageSize: parsed.SizeTier,
- }, nil
}
func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
@@ -795,15 +763,7 @@ func (s *OpenAIGatewayService) buildOpenAIImagesRequest(
}
func buildOpenAIImagesURL(base string, endpoint string) string {
- normalized := strings.TrimRight(strings.TrimSpace(base), "/")
- relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1")
- if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) {
- return normalized
- }
- if strings.HasSuffix(normalized, "/v1") {
- return normalized + relative
- }
- return normalized + endpoint
+ return buildOpenAIEndpointURL(base, endpoint)
}
func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) {
@@ -892,10 +852,10 @@ func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
return dst
}
-func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) {
+func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, []string, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
- return OpenAIUsage{}, 0, err
+ return OpenAIUsage{}, 0, nil, err
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := "application/json"
@@ -907,14 +867,14 @@ func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http
c.Data(resp.StatusCode, contentType, body)
usage, _ := extractOpenAIUsageFromJSONBytes(body)
- return usage, extractOpenAIImageCountFromJSONBytes(body), nil
+ return usage, extractOpenAIImageCountFromJSONBytes(body), collectOpenAIResponseImageOutputSizesFromJSONBytes(body), nil
}
func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
resp *http.Response,
c *gin.Context,
startTime time.Time,
-) (OpenAIUsage, int, *int, error) {
+) (OpenAIUsage, int, []string, *int, error) {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := strings.TrimSpace(resp.Header.Get("Content-Type"))
if contentType == "" {
@@ -925,7 +885,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
flusher, ok := c.Writer.(http.Flusher)
if !ok {
- return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
+ return OpenAIUsage{}, 0, nil, nil, fmt.Errorf("streaming is not supported by response writer")
}
usage := OpenAIUsage{}
@@ -1010,12 +970,12 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
}
if err != nil {
flushSSEEvent()
- return usage, imageCounter.Count(), firstTokenMs, err
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, err
}
}
flushSSEEvent()
finalizeFallbackBody()
- return usage, imageCounter.Count(), firstTokenMs, nil
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, nil
}
type readEvent struct {
@@ -1082,11 +1042,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
if !ok {
flushSSEEvent()
finalizeFallbackBody()
- return usage, imageCounter.Count(), firstTokenMs, nil
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, nil
}
if ev.err != nil {
flushSSEEvent()
- return usage, imageCounter.Count(), firstTokenMs, ev.err
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, ev.err
}
processLine(ev.line)
case <-intervalCh:
@@ -1095,11 +1055,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse(
continue
}
if clientDisconnected {
- return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval)
_ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
- return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
+ return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go
index 25cd8228..c89c2aaf 100644
--- a/backend/internal/service/openai_images_responses.go
+++ b/backend/internal/service/openai_images_responses.go
@@ -72,6 +72,22 @@ func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIRe
}
}
+func openAIResponsesImageResultSizes(results []openAIResponsesImageResult) []string {
+ if len(results) == 0 {
+ return nil
+ }
+ sizes := make([]string, 0, len(results))
+ for _, result := range results {
+ if size := strings.TrimSpace(result.Size); size != "" {
+ sizes = append(sizes, size)
+ }
+ }
+ if len(sizes) == 0 {
+ return nil
+ }
+ return sizes
+}
+
func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) {
switch gjson.GetBytes(payload, "type").String() {
case "response.created", "response.in_progress", "response.completed":
@@ -547,10 +563,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
c *gin.Context,
responseFormat string,
fallbackModel string,
-) (OpenAIUsage, int, error) {
+) (OpenAIUsage, int, []string, error) {
body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
if err != nil {
- return OpenAIUsage{}, 0, err
+ return OpenAIUsage{}, 0, nil, err
}
var usage OpenAIUsage
@@ -559,10 +575,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
})
results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body)
if err != nil {
- return OpenAIUsage{}, 0, err
+ return OpenAIUsage{}, 0, nil, err
}
if len(results) == 0 {
- return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output")
+ return OpenAIUsage{}, 0, nil, fmt.Errorf("upstream did not return image output")
}
if strings.TrimSpace(firstMeta.Model) == "" {
firstMeta.Model = strings.TrimSpace(fallbackModel)
@@ -570,11 +586,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse(
responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat)
if err != nil {
- return OpenAIUsage{}, 0, err
+ return OpenAIUsage{}, 0, nil, err
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody)
- return usage, len(results), nil
+ return usage, len(results), openAIResponsesImageResultSizes(results), nil
}
func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
@@ -584,7 +600,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
responseFormat string,
streamPrefix string,
fallbackModel string,
-) (OpenAIUsage, int, *int, error) {
+) (OpenAIUsage, int, []string, *int, error) {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
@@ -593,7 +609,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
flusher, ok := c.Writer.(http.Flusher)
if !ok {
- return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer")
+ return OpenAIUsage{}, 0, nil, nil, fmt.Errorf("streaming is not supported by response writer")
}
format := strings.ToLower(strings.TrimSpace(responseFormat))
@@ -603,6 +619,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
usage := OpenAIUsage{}
imageCount := 0
+ var imageOutputSizes []string
var firstTokenMs *int
emitted := make(map[string]struct{})
pendingResults := make([]openAIResponsesImageResult, 0, 1)
@@ -713,6 +730,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
+ imageOutputSizes = openAIResponsesImageResultSizes(finalResults)
processDataDone = true
}
}
@@ -753,6 +771,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload)
}
imageCount = len(emitted)
+ imageOutputSizes = openAIResponsesImageResultSizes(pendingResults)
return nil
}
@@ -769,33 +788,33 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
line, err := reader.ReadBytes('\n')
done, processErr := processLine(line)
if processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
}
if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
if err == io.EOF {
break
}
if err != nil {
if done, processErr := flushData(); processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
} else if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error()))
- return usage, imageCount, firstTokenMs, err
+ return usage, imageCount, imageOutputSizes, firstTokenMs, err
}
}
if done, processErr := flushData(); processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
} else if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
if err := finalizePending(); err != nil {
- return usage, imageCount, firstTokenMs, err
+ return usage, imageCount, imageOutputSizes, firstTokenMs, err
}
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
type readEvent struct {
@@ -861,30 +880,30 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
case ev, ok := <-events:
if !ok {
if done, processErr := flushData(); processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
} else if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
if err := finalizePending(); err != nil {
- return usage, imageCount, firstTokenMs, err
+ return usage, imageCount, imageOutputSizes, firstTokenMs, err
}
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
if ev.err != nil {
if done, processErr := flushData(); processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
} else if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error()))
- return usage, imageCount, firstTokenMs, ev.err
+ return usage, imageCount, imageOutputSizes, firstTokenMs, ev.err
}
done, processErr := processLine(ev.line)
if processErr != nil {
- return usage, imageCount, firstTokenMs, processErr
+ return usage, imageCount, imageOutputSizes, firstTokenMs, processErr
}
if done {
- return usage, imageCount, firstTokenMs, nil
+ return usage, imageCount, imageOutputSizes, firstTokenMs, nil
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
@@ -892,11 +911,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse(
continue
}
if clientDisconnected {
- return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
+ return usage, imageCount, imageOutputSizes, firstTokenMs, fmt.Errorf("image stream incomplete after timeout")
}
logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval)
s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval)))
- return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout")
+ return usage, imageCount, imageOutputSizes, firstTokenMs, fmt.Errorf("image stream data interval timeout")
case <-keepaliveCh:
if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval {
continue
@@ -948,7 +967,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
)
}
- upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream)
+ upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx)
defer releaseUpstreamCtx()
token, _, err := s.GetAccessToken(upstreamCtx, account)
@@ -960,8 +979,6 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
if err != nil {
return nil, err
}
- setOpsUpstreamRequestBody(c, responsesBody)
-
upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false)
if err != nil {
return nil, err
@@ -1019,31 +1036,34 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
defer func() { _ = resp.Body.Close() }()
var (
- usage OpenAIUsage
- imageCount int
- firstTokenMs *int
+ usage OpenAIUsage
+ imageCount int
+ imageOutputSizes []string
+ firstTokenMs *int
)
if parsed.Stream {
- usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
+ usage, imageCount, imageOutputSizes, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel)
if err != nil {
if imageCount > 0 {
return &OpenAIForwardResult{
- RequestID: resp.Header.Get("x-request-id"),
- Usage: usage,
- Model: requestModel,
- UpstreamModel: requestModel,
- Stream: parsed.Stream,
- ResponseHeaders: resp.Header.Clone(),
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- ImageCount: imageCount,
- ImageSize: parsed.SizeTier,
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: requestModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ ImageInputSize: parsed.Size,
+ ImageOutputSizes: imageOutputSizes,
}, err
}
return nil, err
}
} else {
- usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
+ usage, imageCount, imageOutputSizes, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel)
if err != nil {
return nil, err
}
@@ -1052,15 +1072,17 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth(
imageCount = parsed.N
}
return &OpenAIForwardResult{
- RequestID: resp.Header.Get("x-request-id"),
- Usage: usage,
- Model: requestModel,
- UpstreamModel: requestModel,
- Stream: parsed.Stream,
- ResponseHeaders: resp.Header.Clone(),
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
- ImageCount: imageCount,
- ImageSize: parsed.SizeTier,
+ RequestID: resp.Header.Get("x-request-id"),
+ Usage: usage,
+ Model: requestModel,
+ UpstreamModel: requestModel,
+ Stream: parsed.Stream,
+ ResponseHeaders: resp.Header.Clone(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ ImageCount: imageCount,
+ ImageSize: parsed.SizeTier,
+ ImageInputSize: parsed.Size,
+ ImageOutputSizes: imageOutputSizes,
}, nil
}
diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go
index 45fb24e9..35789d21 100644
--- a/backend/internal/service/openai_images_test.go
+++ b/backend/internal/service/openai_images_test.go
@@ -149,9 +149,9 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCusto
{size: "2048x1152", wantTier: "2K"},
{size: "3840x2160", wantTier: "4K"},
{size: "2160x3840", wantTier: "4K"},
- {size: "1024X768", wantTier: "2K"},
+ {size: "1024X768", wantTier: "1K"},
{size: "1280x768", wantTier: "2K"},
- {size: "2560x1440", wantTier: "2K"},
+ {size: "2560x1440", wantTier: "4K"},
{size: "2560x1600", wantTier: "4K"},
{size: "auto", wantTier: "2K"},
}
@@ -186,7 +186,7 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPass
{size: "2048x1153", wantTier: "2K"},
{size: "4096x1024", wantTier: "4K"},
{size: "3840x1024", wantTier: "4K"},
- {size: "512x512", wantTier: "2K"},
+ {size: "512x512", wantTier: "1K"},
{size: "invalid", wantTier: "2K"},
{size: "999999999999999999999999999x2", wantTier: "2K"},
}
@@ -418,6 +418,10 @@ func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) {
"https://image-upstream.example/v1/images/generations",
buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint),
)
+ require.Equal(t,
+ "https://open.bigmodel.cn/api/paas/v4/images/generations",
+ buildOpenAIImagesURL("https://open.bigmodel.cn/api/paas/v4", openAIImagesGenerationsEndpoint),
+ )
require.Equal(t,
"https://image-upstream.example/v1/images/edits",
buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint),
diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go
index f087ac32..020e8875 100644
--- a/backend/internal/service/openai_model_mapping_test.go
+++ b/backend/internal/service/openai_model_mapping_test.go
@@ -261,6 +261,12 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
model: "gpt-5.4-high",
want: "gpt-5.4",
},
+ {
+ name: "oauth preserves codex auto review model",
+ account: &Account{Type: AccountTypeOAuth},
+ model: "codex-auto-review",
+ want: "codex-auto-review",
+ },
{
name: "apikey preserves custom compatible model",
account: &Account{Type: AccountTypeAPIKey},
@@ -283,3 +289,17 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) {
})
}
}
+
+func TestUsageBillingModelCandidatesPreserveCodexAutoReviewModel(t *testing.T) {
+ candidates := usageBillingModelCandidates("codex-auto-review")
+
+ expected := []string{"codex-auto-review"}
+ if len(candidates) != len(expected) {
+ t.Fatalf("usageBillingModelCandidates(codex-auto-review) = %#v, want %#v", candidates, expected)
+ }
+ for i := range expected {
+ if candidates[i] != expected[i] {
+ t.Fatalf("usageBillingModelCandidates(codex-auto-review) = %#v, want %#v", candidates, expected)
+ }
+ }
+}
diff --git a/backend/internal/service/openai_silent_refusal.go b/backend/internal/service/openai_silent_refusal.go
new file mode 100644
index 00000000..27b71b75
--- /dev/null
+++ b/backend/internal/service/openai_silent_refusal.go
@@ -0,0 +1,293 @@
+package service
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+)
+
+const (
+ openAISilentRefusalMinRequestBodyBytes = 64 * 1024
+ openAISilentRefusalErrorCode = "openai_silent_refusal"
+ openAISilentRefusalUpstreamMessage = "OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage"
+ openAISilentRefusalClientMessage = "Upstream returned an empty completion without usage; no fallback account was available"
+)
+
+type openAIChatSilentRefusalDetector struct {
+ enabled bool
+ sawContent bool
+ sawToolCall bool
+ sawFunctionCall bool
+ sawUsage bool
+ sawError bool
+ sawReasoning bool
+ sawFinish bool
+ finishReason string
+}
+
+func newOpenAIChatSilentRefusalDetector(requestBodyLen int) *openAIChatSilentRefusalDetector {
+ return &openAIChatSilentRefusalDetector{
+ enabled: requestBodyLen >= openAISilentRefusalMinRequestBodyBytes,
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) Enabled() bool {
+ return d != nil && d.enabled
+}
+
+func (d *openAIChatSilentRefusalDetector) ObserveSSELine(line string) {
+ if d == nil || !d.enabled {
+ return
+ }
+ if eventType, ok := extractOpenAISSEEventLine(line); ok {
+ d.observeEventType(eventType)
+ return
+ }
+ if payload, ok := extractOpenAISSEDataLine(line); ok {
+ d.ObservePayload([]byte(payload))
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) ObservePayload(payload []byte) {
+ if d == nil || !d.enabled {
+ return
+ }
+ payload = bytes.TrimSpace(payload)
+ if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
+ return
+ }
+ if !gjson.ValidBytes(payload) {
+ return
+ }
+
+ eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
+ d.observeEventType(eventType)
+
+ if gjson.GetBytes(payload, "error").Exists() {
+ d.sawError = true
+ }
+ if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() {
+ d.sawUsage = true
+ }
+ if usage := gjson.GetBytes(payload, "response.usage"); usage.Exists() && usage.IsObject() {
+ d.sawUsage = true
+ }
+
+ d.observeChatChoicesPayload(payload)
+ d.observeResponsesPayload(payload, eventType)
+}
+
+func (d *openAIChatSilentRefusalDetector) ObserveChatChunk(chunk apicompat.ChatCompletionsChunk) {
+ if d == nil || !d.enabled {
+ return
+ }
+ if chunk.Usage != nil {
+ d.sawUsage = true
+ }
+ for _, choice := range chunk.Choices {
+ if choice.FinishReason != nil {
+ d.observeFinishReason(*choice.FinishReason)
+ }
+ delta := choice.Delta
+ if delta.Content != nil && *delta.Content != "" {
+ d.sawContent = true
+ }
+ if delta.ReasoningContent != nil {
+ d.sawReasoning = true
+ }
+ if len(delta.ToolCalls) > 0 {
+ d.sawToolCall = true
+ }
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) ShouldReleaseClientOutput() bool {
+ if d == nil || !d.enabled {
+ return true
+ }
+ if d.sawContent || d.sawToolCall || d.sawFunctionCall || d.sawUsage || d.sawError || d.sawReasoning {
+ return true
+ }
+ return d.sawFinish && d.finishReason != "" && d.finishReason != "stop"
+}
+
+func (d *openAIChatSilentRefusalDetector) IsSilentRefusal() bool {
+ if d == nil || !d.enabled {
+ return false
+ }
+ return !d.sawContent &&
+ !d.sawToolCall &&
+ !d.sawFunctionCall &&
+ !d.sawUsage &&
+ !d.sawError &&
+ !d.sawReasoning &&
+ d.sawFinish &&
+ d.finishReason == "stop"
+}
+
+func (d *openAIChatSilentRefusalDetector) observeEventType(eventType string) {
+ eventType = strings.TrimSpace(eventType)
+ if eventType == "" {
+ return
+ }
+ if eventType == "error" || eventType == "response.failed" {
+ d.sawError = true
+ }
+ if strings.Contains(eventType, "reasoning") || strings.Contains(eventType, "reasoning_summary") {
+ d.sawReasoning = true
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) observeFinishReason(reason string) {
+ reason = strings.TrimSpace(reason)
+ if reason == "" {
+ return
+ }
+ d.sawFinish = true
+ d.finishReason = reason
+}
+
+func (d *openAIChatSilentRefusalDetector) observeChatChoicesPayload(payload []byte) {
+ choices := gjson.GetBytes(payload, "choices")
+ if !choices.Exists() || !choices.IsArray() {
+ return
+ }
+ for _, choice := range choices.Array() {
+ if finish := choice.Get("finish_reason"); finish.Exists() {
+ d.observeFinishReason(finish.String())
+ }
+ delta := choice.Get("delta")
+ if !delta.Exists() {
+ continue
+ }
+ if content := delta.Get("content"); content.Exists() && content.String() != "" {
+ d.sawContent = true
+ }
+ if delta.Get("tool_calls").Exists() {
+ d.sawToolCall = true
+ }
+ if delta.Get("function_call").Exists() {
+ d.sawFunctionCall = true
+ }
+ if delta.Get("reasoning").Exists() ||
+ delta.Get("reasoning_content").Exists() ||
+ delta.Get("reasoning_summary").Exists() {
+ d.sawReasoning = true
+ }
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) observeResponsesPayload(payload []byte, eventType string) {
+ switch eventType {
+ case "response.output_text.delta":
+ if gjson.GetBytes(payload, "delta").String() != "" {
+ d.sawContent = true
+ }
+ case "response.output_item.added":
+ switch strings.TrimSpace(gjson.GetBytes(payload, "item.type").String()) {
+ case "function_call":
+ d.sawToolCall = true
+ case "reasoning":
+ d.sawReasoning = true
+ }
+ case "response.function_call_arguments.delta":
+ d.sawToolCall = true
+ case "response.reasoning_summary_text.delta", "response.reasoning_summary_text.done":
+ d.sawReasoning = true
+ case "response.completed", "response.done":
+ d.observeFinishReason("stop")
+ case "response.incomplete":
+ d.observeFinishReason("length")
+ case "response.failed":
+ d.sawError = true
+ }
+
+ if output := gjson.GetBytes(payload, "response.output"); output.Exists() && output.IsArray() {
+ for _, item := range output.Array() {
+ switch strings.TrimSpace(item.Get("type").String()) {
+ case "function_call":
+ d.sawToolCall = true
+ case "reasoning":
+ d.sawReasoning = true
+ case "message":
+ d.observeResponseMessageItem(item)
+ }
+ }
+ }
+}
+
+func (d *openAIChatSilentRefusalDetector) observeResponseMessageItem(item gjson.Result) {
+ content := item.Get("content")
+ if !content.Exists() || !content.IsArray() {
+ return
+ }
+ for _, part := range content.Array() {
+ if part.Get("text").String() != "" {
+ d.sawContent = true
+ return
+ }
+ }
+}
+
+func newOpenAISilentRefusalFailoverError(c *gin.Context, account *Account, upstreamRequestID string) *UpstreamFailoverError {
+ accountID := int64(0)
+ accountName := ""
+ platform := PlatformOpenAI
+ if account != nil {
+ accountID = account.ID
+ accountName = account.Name
+ platform = account.Platform
+ }
+
+ setOpsUpstreamError(c, http.StatusBadGateway, openAISilentRefusalUpstreamMessage, "")
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: platform,
+ AccountID: accountID,
+ AccountName: accountName,
+ UpstreamStatusCode: http.StatusBadGateway,
+ UpstreamRequestID: upstreamRequestID,
+ Kind: "failover",
+ Message: openAISilentRefusalUpstreamMessage,
+ })
+
+ headers := http.Header{}
+ if strings.TrimSpace(upstreamRequestID) != "" {
+ headers.Set("x-request-id", strings.TrimSpace(upstreamRequestID))
+ }
+ return &UpstreamFailoverError{
+ StatusCode: http.StatusBadGateway,
+ ResponseBody: openAISilentRefusalErrorBody(),
+ ResponseHeaders: headers,
+ }
+}
+
+func openAISilentRefusalErrorBody() []byte {
+ body, err := json.Marshal(map[string]any{
+ "error": map[string]any{
+ "type": "upstream_error",
+ "code": openAISilentRefusalErrorCode,
+ "message": openAISilentRefusalUpstreamMessage,
+ },
+ })
+ if err != nil {
+ return []byte(`{"error":{"type":"upstream_error","code":"openai_silent_refusal","message":"OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage"}}`)
+ }
+ return body
+}
+
+// IsOpenAISilentRefusalErrorBody reports whether a failover body was produced
+// by the OpenAI silent-refusal detector.
+func IsOpenAISilentRefusalErrorBody(body []byte) bool {
+ return strings.TrimSpace(gjson.GetBytes(body, "error.code").String()) == openAISilentRefusalErrorCode
+}
+
+// OpenAISilentRefusalClientMessage returns the exhausted-failover client message
+// for OpenAI silent refusals.
+func OpenAISilentRefusalClientMessage() string {
+ return openAISilentRefusalClientMessage
+}
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
index a680d451..5b55d200 100644
--- a/backend/internal/service/openai_token_provider.go
+++ b/backend/internal/service/openai_token_provider.go
@@ -154,7 +154,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" {
if expiresAt != nil && !time.Now().Before(*expiresAt) {
- return "", errors.New("openai access_token expired and refresh_token is missing")
+ const reason = "openai access_token expired and refresh_token is missing"
+ // 永久故障:缺失 refresh_token 时账号无法自愈,必须立即从调度池剔除,
+ // 否则会被反复选中、每次都在 token 阶段直接返回错误,对用户呈现持续 502。
+ p.disableAccountMissingRefreshToken(account, reason)
+ return "", errors.New(reason)
}
needsRefresh = false
}
@@ -261,6 +265,39 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
+// disableAccountMissingRefreshToken 在请求路径上发现 OpenAI OAuth 账号
+// 凭证已过期且 refresh_token 缺失时,将账号标记为 error 状态。
+// 这是一种永久性故障:仅靠后续请求或 TokenRefreshService 不会自愈
+// (NeedsRefresh 也会因 refresh_token 为空直接跳过),
+// 必须主动剔除以避免账号被持续选中导致用户端反复 502。
+// 使用 background context 是因为请求 context 可能很快结束。
+func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account, reason string) {
+ if p == nil || p.accountRepo == nil || account == nil {
+ return
+ }
+ bgCtx := context.Background()
+ if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil {
+ slog.Warn("openai_token_provider.set_error_failed",
+ "account_id", account.ID,
+ "error", err,
+ )
+ return
+ }
+ if p.tokenCache != nil {
+ cacheKey := OpenAITokenCacheKey(account)
+ if err := p.tokenCache.DeleteAccessToken(bgCtx, cacheKey); err != nil {
+ slog.Warn("openai_token_provider.cache_delete_failed",
+ "account_id", account.ID,
+ "error", err,
+ )
+ }
+ }
+ slog.Warn("openai_token_provider.account_disabled_missing_refresh_token",
+ "account_id", account.ID,
+ "reason", reason,
+ )
+}
+
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
wait := openAILockInitialWait
totalWaitMs := int64(0)
diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go
index 4b69db8a..df2f0f3e 100644
--- a/backend/internal/service/openai_token_provider_test.go
+++ b/backend/internal/service/openai_token_provider_test.go
@@ -930,3 +930,34 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
}
+
+func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) {
+ cache := newOpenAITokenCacheStub()
+ repo := &rateLimitAccountRepoStub{}
+
+ expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339)
+ account := &Account{
+ ID: 2881,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "expired-access-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ cacheKey := OpenAITokenCacheKey(account)
+ cache.tokens[cacheKey] = "stale-cached-token"
+ // Force the provider past the cache hit branch.
+ cache.getErr = errors.New("simulated cache miss")
+
+ provider := NewOpenAITokenProvider(repo, cache, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Empty(t, token)
+ require.Contains(t, err.Error(), "refresh_token is missing")
+
+ require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once")
+ require.Contains(t, repo.lastErrorMsg, "refresh_token is missing")
+}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 77cf7d95..920a2239 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -399,15 +399,9 @@ func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIU
if usage == nil || len(message) == 0 {
return
}
- values := gjson.GetManyBytes(
- message,
- "response.usage.input_tokens",
- "response.usage.output_tokens",
- "response.usage.input_tokens_details.cached_tokens",
- )
- usage.InputTokens = int(values[0].Int())
- usage.OutputTokens = int(values[1].Int())
- usage.CacheReadInputTokens = int(values[2].Int())
+ if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(message); ok {
+ *usage = parsedUsage
+ }
}
func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
@@ -2351,18 +2345,19 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
)
return &OpenAIForwardResult{
- RequestID: responseID,
- Usage: *usage,
- Model: originalModel,
- UpstreamModel: mappedModel,
- ImageCount: imageCounter.Count(),
- ServiceTier: extractOpenAIServiceTier(reqBody),
- ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
- Stream: reqStream,
- OpenAIWSMode: true,
- ResponseHeaders: lease.HandshakeHeaders(),
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
+ RequestID: responseID,
+ Usage: *usage,
+ Model: originalModel,
+ UpstreamModel: mappedModel,
+ ImageCount: imageCounter.Count(),
+ ImageOutputSizes: imageCounter.Sizes(),
+ ServiceTier: extractOpenAIServiceTier(reqBody),
+ ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ ResponseHeaders: lease.HandshakeHeaders(),
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
}, nil
}
@@ -2464,6 +2459,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
originalModel string
imageBillingModel string
imageSizeTier string
+ imageInputSize string
payloadBytes int
}
@@ -2567,12 +2563,16 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
imageBillingModel := ""
imageSizeTier := ""
+ imageInputSize := ""
if imageIntent {
var imageCfgErr error
- imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel)
+ imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(normalized, originalModel)
if imageCfgErr != nil {
return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr)
}
+ imageBillingModel = imageCfg.Model
+ imageSizeTier = imageCfg.SizeTier
+ imageInputSize = imageCfg.InputSize
}
// Apply OpenAI Fast Policy on the response.create frame using the same
@@ -2621,6 +2621,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
originalModel: originalModel,
imageBillingModel: imageBillingModel,
imageSizeTier: imageSizeTier,
+ imageInputSize: imageInputSize,
payloadBytes: len(normalized),
}, nil
}
@@ -2822,7 +2823,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil
}
- sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) {
+ sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string, imageInputSize string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
@@ -3046,6 +3047,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if imageCount > 0 {
result.ImageCount = imageCount
result.ImageSize = imageSizeTier
+ result.ImageInputSize = imageInputSize
+ result.ImageOutputSizes = imageCounter.Sizes()
result.BillingModel = imageBillingModel
}
return result, nil
@@ -3057,6 +3060,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentOriginalModel := firstPayload.originalModel
currentImageBillingModel := firstPayload.imageBillingModel
currentImageSizeTier := firstPayload.imageSizeTier
+ currentImageInputSize := firstPayload.imageInputSize
currentPayloadBytes := firstPayload.payloadBytes
isStrictAffinityTurn := func(payload []byte) bool {
if !storeDisabled {
@@ -3534,7 +3538,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
- result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier)
+ result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier, currentImageInputSize)
if relayErr != nil {
lastTurnClean = false
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
@@ -3658,6 +3662,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentOriginalModel = nextPayload.originalModel
currentImageBillingModel = nextPayload.imageBillingModel
currentImageSizeTier = nextPayload.imageSizeTier
+ currentImageInputSize = nextPayload.imageInputSize
currentPayloadBytes = nextPayload.payloadBytes
storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
if !storeDisabled {
diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
index 76167603..0350bde9 100644
--- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
+++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
@@ -29,6 +29,14 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
require.Equal(t, 11, usage.InputTokens)
require.Equal(t, 7, usage.OutputTokens)
require.Equal(t, 3, usage.CacheReadInputTokens)
+
+ parseOpenAIWSResponseUsageFromCompletedEvent(
+ []byte(`{"type":"response.completed","response":{"usage":{"prompt_tokens":19,"completion_tokens":5,"prompt_tokens_details":{"cached_tokens":4}}}}`),
+ usage,
+ )
+ require.Equal(t, 19, usage.InputTokens)
+ require.Equal(t, 5, usage.OutputTokens)
+ require.Equal(t, 4, usage.CacheReadInputTokens)
}
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go
index af8ee195..2b7e2add 100644
--- a/backend/internal/service/openai_ws_v2/passthrough_relay.go
+++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go
@@ -82,6 +82,7 @@ type relayState struct {
terminalEventType string
firstTokenMs *int
turnTimingByID map[string]*relayTurnTiming
+ activeTurn *relayTurnTiming
}
type relayExitSignal struct {
@@ -550,6 +551,12 @@ func observeUpstreamMessage(
if ms >= 0 {
state.firstTokenMs = &ms
}
+ if state.activeTurn != nil && state.activeTurn.firstTokenMs == nil {
+ tms := int(now.Sub(state.activeTurn.startAt).Milliseconds())
+ if tms >= 0 {
+ state.activeTurn.firstTokenMs = &tms
+ }
+ }
}
parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
observed := observedUpstreamEvent{
@@ -622,6 +629,7 @@ func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now
if !ok || timing == nil || timing.startAt.IsZero() {
timing = &relayTurnTiming{startAt: now}
state.turnTimingByID[responseID] = timing
+ state.activeTurn = timing
return timing
}
return timing
@@ -636,6 +644,9 @@ func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayT
return relayTurnTiming{}, false
}
delete(state.turnTimingByID, responseID)
+ if state.activeTurn == timing {
+ state.activeTurn = nil
+ }
return *timing, true
}
diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go
index ff9b7311..cdd41a05 100644
--- a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go
+++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go
@@ -750,3 +750,67 @@ func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageT
func (c *errorOnWriteFrameConn) Close() error {
return nil
}
+
+func TestRelay_OnTurnComplete_RealOpenAIStream_FirstTokenMs(t *testing.T) {
+ t.Parallel()
+
+ clientConn := newPassthroughTestFrameConn(nil, false)
+ upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.created","response":{"id":"resp_real"}}`),
+ },
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.output_text.delta","delta":"He"}`),
+ },
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.output_text.delta","delta":"llo"}`),
+ },
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
+ },
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.completed","response":{"id":"resp_real","usage":{"input_tokens":2,"output_tokens":3}}}`),
+ },
+ }, true)
+
+ firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ base := time.Unix(0, 0)
+ var nowTick atomic.Int64
+ nowFn := func() time.Time {
+ step := nowTick.Add(1)
+ return base.Add(time.Duration(step) * 10 * time.Millisecond)
+ }
+
+ var turn RelayTurnResult
+ result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+ Now: nowFn,
+ OnTurnComplete: func(current RelayTurnResult) {
+ turn = current
+ },
+ })
+ require.Nil(t, relayExit)
+ require.Equal(t, "resp_real", turn.RequestID)
+ require.Equal(t, "response.completed", turn.TerminalEventType)
+
+ require.NotNil(t, turn.FirstTokenMs, "per-turn FirstTokenMs must be captured for real OpenAI streams")
+ require.Greater(t, turn.Duration.Milliseconds(), int64(0))
+
+ require.Less(t,
+ int64(*turn.FirstTokenMs),
+ turn.Duration.Milliseconds(),
+ "per-turn FirstTokenMs (%dms) should be strictly less than Duration (%dms); "+
+ "equality indicates the bug where first_token is mistakenly stamped on the terminal event",
+ *turn.FirstTokenMs, turn.Duration.Milliseconds(),
+ )
+
+ require.NotNil(t, result.FirstTokenMs)
+ require.Greater(t, *result.FirstTokenMs, 0)
+}
diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
index e2760725..0a89e2dd 100644
--- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go
+++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
@@ -267,9 +267,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
// omits "model" — Realtime clients are allowed to send response.create
// without re-stating the model, in which case the upstream uses the model
// negotiated at session.update time. Without this fallback, an empty
- // model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be
- // silently passed through, defeating the policy on every frame after
- // the first.
+ // model would miss any admin-configured model whitelist and be silently
+ // passed through, defeating that policy on every frame after the first.
capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage)
initialRequestModel := ""
if hooks != nil {
diff --git a/backend/internal/service/ops_cleanup_executor.go b/backend/internal/service/ops_cleanup_executor.go
index 63a7367f..d51863c4 100644
--- a/backend/internal/service/ops_cleanup_executor.go
+++ b/backend/internal/service/ops_cleanup_executor.go
@@ -26,7 +26,6 @@ type opsCleanupTarget struct {
type opsCleanupDeletedCounts struct {
errorLogs int64
- retryAttempts int64
alertEvents int64
systemLogs int64
logAudits int64
@@ -37,9 +36,8 @@ type opsCleanupDeletedCounts struct {
func (c opsCleanupDeletedCounts) String() string {
return fmt.Sprintf(
- "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
+ "error_logs=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d",
c.errorLogs,
- c.retryAttempts,
c.alertEvents,
c.systemLogs,
c.logAudits,
diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go
index 60a690f3..f812c290 100644
--- a/backend/internal/service/ops_cleanup_service.go
+++ b/backend/internal/service/ops_cleanup_service.go
@@ -299,7 +299,6 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet
targets := []opsCleanupTarget{
{effective.ErrorLogRetentionDays, "ops_error_logs", "created_at", false, &out.errorLogs},
- {effective.ErrorLogRetentionDays, "ops_retry_attempts", "created_at", false, &out.retryAttempts},
{effective.ErrorLogRetentionDays, "ops_alert_events", "created_at", false, &out.alertEvents},
{effective.ErrorLogRetentionDays, "ops_system_logs", "created_at", false, &out.systemLogs},
{effective.ErrorLogRetentionDays, "ops_system_log_cleanup_audits", "created_at", false, &out.logAudits},
diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go
index 5fefb74f..ba735346 100644
--- a/backend/internal/service/ops_models.go
+++ b/backend/internal/service/ops_models.go
@@ -37,14 +37,10 @@ type OpsErrorLog struct {
Platform string `json:"platform"`
Model string `json:"model"`
- IsRetryable bool `json:"is_retryable"`
- RetryCount int `json:"retry_count"`
-
Resolved bool `json:"resolved"`
ResolvedAt *time.Time `json:"resolved_at"`
ResolvedByUserID *int64 `json:"resolved_by_user_id"`
ResolvedByUserName string `json:"resolved_by_user_name"`
- ResolvedRetryID *int64 `json:"resolved_retry_id"`
ResolvedStatusRaw string `json:"-"`
ClientRequestID string `json:"client_request_id"`
@@ -89,12 +85,6 @@ type OpsErrorLogDetail struct {
ResponseLatencyMs *int64 `json:"response_latency_ms"`
TimeToFirstTokenMs *int64 `json:"time_to_first_token_ms"`
- // Retry context
- RequestBody string `json:"request_body"`
- RequestBodyTruncated bool `json:"request_body_truncated"`
- RequestBodyBytes *int `json:"request_body_bytes"`
- RequestHeaders string `json:"request_headers,omitempty"`
-
// vNext metric semantics
IsBusinessLimited bool `json:"is_business_limited"`
}
@@ -136,55 +126,3 @@ type OpsErrorLogList struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
}
-
-type OpsRetryAttempt struct {
- ID int64 `json:"id"`
- CreatedAt time.Time `json:"created_at"`
-
- RequestedByUserID int64 `json:"requested_by_user_id"`
- SourceErrorID int64 `json:"source_error_id"`
- Mode string `json:"mode"`
- PinnedAccountID *int64 `json:"pinned_account_id"`
- PinnedAccountName string `json:"pinned_account_name"`
-
- Status string `json:"status"`
- StartedAt *time.Time `json:"started_at"`
- FinishedAt *time.Time `json:"finished_at"`
- DurationMs *int64 `json:"duration_ms"`
-
- // Persisted execution results (best-effort)
- Success *bool `json:"success"`
- HTTPStatusCode *int `json:"http_status_code"`
- UpstreamRequestID *string `json:"upstream_request_id"`
- UsedAccountID *int64 `json:"used_account_id"`
- UsedAccountName string `json:"used_account_name"`
- ResponsePreview *string `json:"response_preview"`
- ResponseTruncated *bool `json:"response_truncated"`
-
- // Optional correlation
- ResultRequestID *string `json:"result_request_id"`
- ResultErrorID *int64 `json:"result_error_id"`
-
- ErrorMessage *string `json:"error_message"`
-}
-
-type OpsRetryResult struct {
- AttemptID int64 `json:"attempt_id"`
- Mode string `json:"mode"`
- Status string `json:"status"`
-
- PinnedAccountID *int64 `json:"pinned_account_id"`
- UsedAccountID *int64 `json:"used_account_id"`
-
- HTTPStatusCode int `json:"http_status_code"`
- UpstreamRequestID string `json:"upstream_request_id"`
-
- ResponsePreview string `json:"response_preview"`
- ResponseTruncated bool `json:"response_truncated"`
-
- ErrorMessage string `json:"error_message"`
-
- StartedAt time.Time `json:"started_at"`
- FinishedAt time.Time `json:"finished_at"`
- DurationMs int64 `json:"duration_ms"`
-}
diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go
index 04bf91c8..30145ed3 100644
--- a/backend/internal/service/ops_port.go
+++ b/backend/internal/service/ops_port.go
@@ -16,11 +16,7 @@ type OpsRepository interface {
DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error)
InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error
- InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
- UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
- GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
- ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
- UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
+ UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedAt *time.Time) error
// Lightweight window stats (for realtime WS / quick sampling).
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
@@ -121,51 +117,9 @@ type OpsInsertErrorLogInput struct {
ResponseLatencyMs *int64
TimeToFirstTokenMs *int64
- RequestBodyJSON *string // sanitized json string (not raw bytes)
- RequestBodyTruncated bool
- RequestBodyBytes *int
- RequestHeadersJSON *string // optional json string
-
- IsRetryable bool
- RetryCount int
-
CreatedAt time.Time
}
-type OpsInsertRetryAttemptInput struct {
- RequestedByUserID int64
- SourceErrorID int64
- Mode string
- PinnedAccountID *int64
-
- // running|queued etc.
- Status string
- StartedAt time.Time
-}
-
-type OpsUpdateRetryAttemptInput struct {
- ID int64
-
- // succeeded|failed
- Status string
- FinishedAt time.Time
- DurationMs int64
-
- // Persisted execution results (best-effort)
- Success *bool
- HTTPStatusCode *int
- UpstreamRequestID *string
- UsedAccountID *int64
- ResponsePreview *string
- ResponseTruncated *bool
-
- // Optional correlation (legacy fields kept)
- ResultRequestID *string
- ResultErrorID *int64
-
- ErrorMessage *string
-}
-
type OpsInsertSystemMetricsInput struct {
CreatedAt time.Time
WindowMinutes int
diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go
index c8c66ec6..4138ea77 100644
--- a/backend/internal/service/ops_repo_mock_test.go
+++ b/backend/internal/service/ops_repo_mock_test.go
@@ -69,23 +69,7 @@ func (m *opsRepoMock) InsertSystemLogCleanupAudit(ctx context.Context, input *Op
return nil
}
-func (m *opsRepoMock) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) {
- return 0, nil
-}
-
-func (m *opsRepoMock) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error {
- return nil
-}
-
-func (m *opsRepoMock) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) {
- return nil, nil
-}
-
-func (m *opsRepoMock) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error) {
- return []*OpsRetryAttempt{}, nil
-}
-
-func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
+func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedAt *time.Time) error {
return nil
}
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
deleted file mode 100644
index bd40d389..00000000
--- a/backend/internal/service/ops_retry.go
+++ /dev/null
@@ -1,726 +0,0 @@
-package service
-
-import (
- "bytes"
- "context"
- "database/sql"
- "encoding/json"
- "errors"
- "fmt"
- "log"
- "net/http"
- "strings"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/domain"
- infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "github.com/gin-gonic/gin"
- "github.com/lib/pq"
-)
-
-const (
- OpsRetryModeClient = "client"
- OpsRetryModeUpstream = "upstream"
-)
-
-const (
- opsRetryStatusRunning = "running"
- opsRetryStatusSucceeded = "succeeded"
- opsRetryStatusFailed = "failed"
-)
-
-const (
- opsRetryTimeout = 60 * time.Second
- opsRetryCaptureBytesLimit = 64 * 1024
- opsRetryResponsePreviewMax = 8 * 1024
- opsRetryMinIntervalPerError = 10 * time.Second
- opsRetryMaxAccountSwitches = 3
-)
-
-var opsRetryRequestHeaderAllowlist = map[string]bool{
- "anthropic-beta": true,
- "anthropic-version": true,
-}
-
-type opsRetryRequestType string
-
-const (
- opsRetryTypeMessages opsRetryRequestType = "messages"
- opsRetryTypeOpenAI opsRetryRequestType = "openai_responses"
- opsRetryTypeGeminiV1B opsRetryRequestType = "gemini_v1beta"
-)
-
-type limitedResponseWriter struct {
- header http.Header
- wroteHeader bool
-
- limit int
- totalWritten int64
- buf bytes.Buffer
-}
-
-func newLimitedResponseWriter(limit int) *limitedResponseWriter {
- if limit <= 0 {
- limit = 1
- }
- return &limitedResponseWriter{
- header: make(http.Header),
- limit: limit,
- }
-}
-
-func (w *limitedResponseWriter) Header() http.Header {
- return w.header
-}
-
-func (w *limitedResponseWriter) WriteHeader(statusCode int) {
- if w.wroteHeader {
- return
- }
- w.wroteHeader = true
-}
-
-func (w *limitedResponseWriter) Write(p []byte) (int, error) {
- if !w.wroteHeader {
- w.WriteHeader(http.StatusOK)
- }
- w.totalWritten += int64(len(p))
-
- if w.buf.Len() < w.limit {
- remaining := w.limit - w.buf.Len()
- if len(p) > remaining {
- _, _ = w.buf.Write(p[:remaining])
- } else {
- _, _ = w.buf.Write(p)
- }
- }
-
- // Pretend we wrote everything to avoid upstream/client code treating it as an error.
- return len(p), nil
-}
-
-func (w *limitedResponseWriter) Flush() {}
-
-func (w *limitedResponseWriter) bodyBytes() []byte {
- return w.buf.Bytes()
-}
-
-func (w *limitedResponseWriter) truncated() bool {
- return w.totalWritten > int64(w.limit)
-}
-
-const (
- OpsRetryModeUpstreamEvent = "upstream_event"
-)
-
-func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
- if err := s.RequireMonitoringEnabled(ctx); err != nil {
- return nil, err
- }
- if s.opsRepo == nil {
- return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
- }
-
- mode = strings.ToLower(strings.TrimSpace(mode))
- switch mode {
- case OpsRetryModeClient, OpsRetryModeUpstream:
- default:
- return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
- }
-
- errorLog, err := s.GetErrorLogByID(ctx, errorID)
- if err != nil {
- return nil, err
- }
- if errorLog == nil {
- return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
- }
- if strings.TrimSpace(errorLog.RequestBody) == "" {
- return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
- }
-
- var pinned *int64
- if mode == OpsRetryModeUpstream {
- if pinnedAccountID != nil && *pinnedAccountID > 0 {
- pinned = pinnedAccountID
- } else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
- pinned = errorLog.AccountID
- } else {
- return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
- }
- }
-
- return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
-}
-
-// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
-// idx is 0-based. It always pins the original event account_id.
-func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
- if err := s.RequireMonitoringEnabled(ctx); err != nil {
- return nil, err
- }
- if s.opsRepo == nil {
- return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
- }
- if idx < 0 {
- return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
- }
-
- errorLog, err := s.GetErrorLogByID(ctx, errorID)
- if err != nil {
- return nil, err
- }
- if errorLog == nil {
- return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
- }
-
- events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
- if err != nil {
- return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
- }
- if idx >= len(events) {
- return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
- }
- ev := events[idx]
- if ev == nil {
- return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
- }
- if ev.AccountID <= 0 {
- return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
- }
-
- upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
- if upstreamBody == "" {
- return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
- }
-
- override := *errorLog
- override.RequestBody = upstreamBody
- pinned := ev.AccountID
-
- // Persist as upstream_event, execute as upstream pinned retry.
- return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
-}
-
-func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
- latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
- if err != nil && !errors.Is(err, sql.ErrNoRows) {
- return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
- }
- if latest != nil {
- if strings.EqualFold(latest.Status, opsRetryStatusRunning) || strings.EqualFold(latest.Status, "queued") {
- return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
- }
-
- lastAttemptAt := latest.CreatedAt
- if latest.FinishedAt != nil && !latest.FinishedAt.IsZero() {
- lastAttemptAt = *latest.FinishedAt
- } else if latest.StartedAt != nil && !latest.StartedAt.IsZero() {
- lastAttemptAt = *latest.StartedAt
- }
-
- if time.Since(lastAttemptAt) < opsRetryMinIntervalPerError {
- return nil, infraerrors.Conflict("OPS_RETRY_TOO_FREQUENT", "Please wait before retrying this error again")
- }
- }
-
- if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
- return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
- }
-
- var pinned *int64
- if execMode == OpsRetryModeUpstream {
- if pinnedAccountID != nil && *pinnedAccountID > 0 {
- pinned = pinnedAccountID
- } else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
- pinned = errorLog.AccountID
- } else {
- return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
- }
- }
-
- startedAt := time.Now()
- attemptID, err := s.opsRepo.InsertRetryAttempt(ctx, &OpsInsertRetryAttemptInput{
- RequestedByUserID: requestedByUserID,
- SourceErrorID: errorID,
- Mode: mode,
- PinnedAccountID: pinned,
- Status: opsRetryStatusRunning,
- StartedAt: startedAt,
- })
- if err != nil {
- var pqErr *pq.Error
- if errors.As(err, &pqErr) && string(pqErr.Code) == "23505" {
- return nil, infraerrors.Conflict("OPS_RETRY_IN_PROGRESS", "A retry is already in progress for this error")
- }
- return nil, infraerrors.InternalServer("OPS_RETRY_CREATE_ATTEMPT_FAILED", "Failed to create retry attempt").WithCause(err)
- }
-
- result := &OpsRetryResult{
- AttemptID: attemptID,
- Mode: mode,
- Status: opsRetryStatusFailed,
- PinnedAccountID: pinned,
- HTTPStatusCode: 0,
- UpstreamRequestID: "",
- ResponsePreview: "",
- ResponseTruncated: false,
- ErrorMessage: "",
- StartedAt: startedAt,
- }
-
- execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
- defer cancel()
-
- execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
-
- finishedAt := time.Now()
- result.FinishedAt = finishedAt
- result.DurationMs = finishedAt.Sub(startedAt).Milliseconds()
-
- if execRes != nil {
- result.Status = execRes.status
- result.UsedAccountID = execRes.usedAccountID
- result.HTTPStatusCode = execRes.httpStatusCode
- result.UpstreamRequestID = execRes.upstreamRequestID
- result.ResponsePreview = execRes.responsePreview
- result.ResponseTruncated = execRes.responseTruncated
- result.ErrorMessage = execRes.errorMessage
- }
-
- updateCtx, updateCancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer updateCancel()
-
- var updateErrMsg *string
- if strings.TrimSpace(result.ErrorMessage) != "" {
- msg := result.ErrorMessage
- updateErrMsg = &msg
- }
- // Keep legacy result_request_id empty; use upstream_request_id instead.
- var resultRequestID *string
-
- finalStatus := result.Status
- if strings.TrimSpace(finalStatus) == "" {
- finalStatus = opsRetryStatusFailed
- }
-
- success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
- httpStatus := result.HTTPStatusCode
- upstreamReqID := result.UpstreamRequestID
- usedAccountID := result.UsedAccountID
- preview := result.ResponsePreview
- truncated := result.ResponseTruncated
-
- if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
- ID: attemptID,
- Status: finalStatus,
- FinishedAt: finishedAt,
- DurationMs: result.DurationMs,
- Success: &success,
- HTTPStatusCode: &httpStatus,
- UpstreamRequestID: &upstreamReqID,
- UsedAccountID: usedAccountID,
- ResponsePreview: &preview,
- ResponseTruncated: &truncated,
- ResultRequestID: resultRequestID,
- ErrorMessage: updateErrMsg,
- }); err != nil {
- log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
- } else if success {
- if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
- log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
- }
- }
-
- return result, nil
-}
-
-type opsRetryExecution struct {
- status string
-
- usedAccountID *int64
- httpStatusCode int
- upstreamRequestID string
-
- responsePreview string
- responseTruncated bool
-
- errorMessage string
-}
-
-func (s *OpsService) executeRetry(ctx context.Context, errorLog *OpsErrorLogDetail, mode string, pinnedAccountID *int64) *opsRetryExecution {
- if errorLog == nil {
- return &opsRetryExecution{
- status: opsRetryStatusFailed,
- errorMessage: "missing error log",
- }
- }
-
- reqType := detectOpsRetryType(errorLog.RequestPath)
- bodyBytes := []byte(errorLog.RequestBody)
-
- switch reqType {
- case opsRetryTypeMessages:
- bodyBytes = FilterThinkingBlocksForRetry(bodyBytes)
- case opsRetryTypeOpenAI, opsRetryTypeGeminiV1B:
- // No-op
- }
-
- switch strings.ToLower(strings.TrimSpace(mode)) {
- case OpsRetryModeUpstream:
- if pinnedAccountID == nil || *pinnedAccountID <= 0 {
- return &opsRetryExecution{
- status: opsRetryStatusFailed,
- errorMessage: "pinned_account_id required for upstream retry",
- }
- }
- return s.executePinnedRetry(ctx, reqType, errorLog, bodyBytes, *pinnedAccountID)
- case OpsRetryModeClient:
- return s.executeClientRetry(ctx, reqType, errorLog, bodyBytes)
- default:
- return &opsRetryExecution{
- status: opsRetryStatusFailed,
- errorMessage: "invalid retry mode",
- }
- }
-}
-
-func detectOpsRetryType(path string) opsRetryRequestType {
- p := strings.ToLower(strings.TrimSpace(path))
- switch {
- case strings.Contains(p, "/responses"), strings.Contains(p, "/images/"):
- return opsRetryTypeOpenAI
- case strings.Contains(p, "/v1beta/"):
- return opsRetryTypeGeminiV1B
- default:
- return opsRetryTypeMessages
- }
-}
-
-func (s *OpsService) executePinnedRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, pinnedAccountID int64) *opsRetryExecution {
- if s.accountRepo == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account repository not available"}
- }
-
- account, err := s.accountRepo.GetByID(ctx, pinnedAccountID)
- if err != nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: fmt.Sprintf("account not found: %v", err)}
- }
- if account == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account not found"}
- }
- if !account.IsSchedulable() {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account is not schedulable"}
- }
- if errorLog.GroupID != nil && *errorLog.GroupID > 0 {
- if !containsInt64(account.GroupIDs, *errorLog.GroupID) {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "pinned account is not in the same group as the original request"}
- }
- }
-
- var release func()
- if s.concurrencyService != nil {
- acq, err := s.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
- if err != nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: fmt.Sprintf("acquire account slot failed: %v", err)}
- }
- if acq == nil || !acq.Acquired {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "account concurrency limit reached"}
- }
- release = acq.ReleaseFunc
- }
- if release != nil {
- defer release()
- }
-
- usedID := account.ID
- exec := s.executeWithAccount(ctx, reqType, errorLog, body, account)
- exec.usedAccountID = &usedID
- if exec.status == "" {
- exec.status = opsRetryStatusFailed
- }
- return exec
-}
-
-func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) *opsRetryExecution {
- groupID := errorLog.GroupID
- if groupID == nil || *groupID <= 0 {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "group_id missing; cannot reselect account"}
- }
-
- model, stream, parsedErr := extractRetryModelAndStream(reqType, errorLog, body)
- if parsedErr != nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: parsedErr.Error()}
- }
- _ = stream
-
- excluded := make(map[int64]struct{})
- switches := 0
-
- for {
- if switches >= opsRetryMaxAccountSwitches {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "retry failed after exhausting account failovers"}
- }
-
- selection, selErr := s.selectAccountForRetry(ctx, reqType, groupID, model, excluded)
- if selErr != nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: selErr.Error()}
- }
- if selection == nil || selection.Account == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: ErrNoAvailableAccounts.Error()}
- }
-
- account := selection.Account
- if !selection.Acquired || selection.ReleaseFunc == nil {
- excluded[account.ID] = struct{}{}
- switches++
- continue
- }
-
- attemptCtx := ctx
- if switches > 0 {
- attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false)
- }
- exec := func() *opsRetryExecution {
- defer selection.ReleaseFunc()
- return s.executeWithAccount(attemptCtx, reqType, errorLog, body, account)
- }()
-
- if exec != nil {
- if exec.status == opsRetryStatusSucceeded {
- usedID := account.ID
- exec.usedAccountID = &usedID
- return exec
- }
- // If the gateway services ask for failover, try another account.
- if s.isFailoverError(exec.errorMessage) {
- excluded[account.ID] = struct{}{}
- switches++
- continue
- }
- usedID := account.ID
- exec.usedAccountID = &usedID
- return exec
- }
-
- excluded[account.ID] = struct{}{}
- switches++
- }
-}
-
-func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetryRequestType, groupID *int64, model string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
- switch reqType {
- case opsRetryTypeOpenAI:
- if s.openAIGatewayService == nil {
- return nil, fmt.Errorf("openai gateway service not available")
- }
- return s.openAIGatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
- case opsRetryTypeGeminiV1B, opsRetryTypeMessages:
- if s.gatewayService == nil {
- return nil, fmt.Errorf("gateway service not available")
- }
- return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "", int64(0)) // 重试不使用会话限制
- default:
- return nil, fmt.Errorf("unsupported retry type: %s", reqType)
- }
-}
-
-func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) {
- switch reqType {
- case opsRetryTypeMessages:
- parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
- if parseErr != nil {
- return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr)
- }
- return parsed.Model, parsed.Stream, nil
- case opsRetryTypeOpenAI:
- var v struct {
- Model string `json:"model"`
- Stream bool `json:"stream"`
- }
- if err := json.Unmarshal(body, &v); err != nil {
- return "", false, fmt.Errorf("failed to parse openai request body: %w", err)
- }
- return strings.TrimSpace(v.Model), v.Stream, nil
- case opsRetryTypeGeminiV1B:
- if strings.TrimSpace(errorLog.Model) == "" {
- return "", false, fmt.Errorf("missing model for gemini v1beta retry")
- }
- return strings.TrimSpace(errorLog.Model), errorLog.Stream, nil
- default:
- return "", false, fmt.Errorf("unsupported retry type: %s", reqType)
- }
-}
-
-func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte, account *Account) *opsRetryExecution {
- if account == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "missing account"}
- }
-
- c, w := newOpsRetryContext(ctx, errorLog)
-
- var err error
- switch reqType {
- case opsRetryTypeOpenAI:
- if s.openAIGatewayService == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "openai gateway service not available"}
- }
- _, err = s.openAIGatewayService.Forward(ctx, c, account, body)
- case opsRetryTypeGeminiV1B:
- if s.geminiCompatService == nil || s.antigravityGatewayService == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini services not available"}
- }
- modelName := strings.TrimSpace(errorLog.Model)
- action := "generateContent"
- if errorLog.Stream {
- action = "streamGenerateContent"
- }
- if account.Platform == PlatformAntigravity {
- _, err = s.antigravityGatewayService.ForwardGemini(ctx, c, account, modelName, action, errorLog.Stream, body, false)
- } else {
- _, err = s.geminiCompatService.ForwardNative(ctx, c, account, modelName, action, errorLog.Stream, body)
- }
- case opsRetryTypeMessages:
- switch account.Platform {
- case PlatformAntigravity:
- if s.antigravityGatewayService == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "antigravity gateway service not available"}
- }
- _, err = s.antigravityGatewayService.Forward(ctx, c, account, body, false)
- case PlatformGemini:
- if s.geminiCompatService == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gemini gateway service not available"}
- }
- _, err = s.geminiCompatService.Forward(ctx, c, account, body)
- default:
- if s.gatewayService == nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"}
- }
- parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic)
- if parseErr != nil {
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"}
- }
- _, err = s.gatewayService.Forward(ctx, c, account, parsedReq)
- }
- default:
- return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "unsupported retry type"}
- }
-
- statusCode := http.StatusOK
- if c != nil && c.Writer != nil {
- statusCode = c.Writer.Status()
- }
-
- upstreamReqID := extractUpstreamRequestID(c)
- preview, truncated := extractResponsePreview(w)
-
- exec := &opsRetryExecution{
- status: opsRetryStatusFailed,
- httpStatusCode: statusCode,
- upstreamRequestID: upstreamReqID,
- responsePreview: preview,
- responseTruncated: truncated,
- errorMessage: "",
- }
-
- if err == nil && statusCode < 400 {
- exec.status = opsRetryStatusSucceeded
- return exec
- }
-
- if err != nil {
- exec.errorMessage = err.Error()
- } else {
- exec.errorMessage = fmt.Sprintf("upstream returned status %d", statusCode)
- }
-
- return exec
-}
-
-func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin.Context, *limitedResponseWriter) {
- w := newLimitedResponseWriter(opsRetryCaptureBytesLimit)
- c, _ := gin.CreateTestContext(w)
-
- path := "/"
- if errorLog != nil && strings.TrimSpace(errorLog.RequestPath) != "" {
- path = errorLog.RequestPath
- }
-
- req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "http://localhost"+path, bytes.NewReader(nil))
- req.Header.Set("content-type", "application/json")
- if errorLog != nil && strings.TrimSpace(errorLog.UserAgent) != "" {
- req.Header.Set("user-agent", errorLog.UserAgent)
- }
- // Restore a minimal, whitelisted subset of request headers to improve retry fidelity
- // (e.g. anthropic-beta / anthropic-version). Never replay auth credentials.
- if errorLog != nil && strings.TrimSpace(errorLog.RequestHeaders) != "" {
- var stored map[string]string
- if err := json.Unmarshal([]byte(errorLog.RequestHeaders), &stored); err == nil {
- for k, v := range stored {
- key := strings.TrimSpace(k)
- if key == "" {
- continue
- }
- if !opsRetryRequestHeaderAllowlist[strings.ToLower(key)] {
- continue
- }
- val := strings.TrimSpace(v)
- if val == "" {
- continue
- }
- req.Header.Set(key, val)
- }
- }
- }
-
- c.Request = req
- SetOpenAIClientTransport(c, OpenAIClientTransportHTTP)
- return c, w
-}
-
-func extractUpstreamRequestID(c *gin.Context) string {
- if c == nil || c.Writer == nil {
- return ""
- }
- h := c.Writer.Header()
- if h == nil {
- return ""
- }
- for _, key := range []string{"x-request-id", "X-Request-Id", "X-Request-ID"} {
- if v := strings.TrimSpace(h.Get(key)); v != "" {
- return v
- }
- }
- return ""
-}
-
-func extractResponsePreview(w *limitedResponseWriter) (preview string, truncated bool) {
- if w == nil {
- return "", false
- }
- b := bytes.TrimSpace(w.bodyBytes())
- if len(b) == 0 {
- return "", w.truncated()
- }
- if len(b) > opsRetryResponsePreviewMax {
- return string(b[:opsRetryResponsePreviewMax]), true
- }
- return string(b), w.truncated()
-}
-
-func containsInt64(items []int64, needle int64) bool {
- for _, v := range items {
- if v == needle {
- return true
- }
- }
- return false
-}
-
-func (s *OpsService) isFailoverError(message string) bool {
- msg := strings.ToLower(strings.TrimSpace(message))
- if msg == "" {
- return false
- }
- return strings.Contains(msg, "upstream error:") && strings.Contains(msg, "failover")
-}
diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go
deleted file mode 100644
index a8c26ee4..00000000
--- a/backend/internal/service/ops_retry_context_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-package service
-
-import (
- "context"
- "testing"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) {
- errorLog := &OpsErrorLogDetail{
- OpsErrorLog: OpsErrorLog{
- RequestPath: "/openai/v1/responses",
- },
- UserAgent: "ops-retry-agent/1.0",
- RequestHeaders: `{
- "anthropic-beta":"beta-v1",
- "ANTHROPIC-VERSION":"2023-06-01",
- "authorization":"Bearer should-not-forward"
- }`,
- }
-
- c, w := newOpsRetryContext(context.Background(), errorLog)
- require.NotNil(t, c)
- require.NotNil(t, w)
- require.NotNil(t, c.Request)
-
- require.Equal(t, "/openai/v1/responses", c.Request.URL.Path)
- require.Equal(t, "application/json", c.Request.Header.Get("Content-Type"))
- require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent"))
- require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta"))
- require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version"))
- require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放")
- require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
-}
-
-func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) {
- errorLog := &OpsErrorLogDetail{
- RequestHeaders: "{invalid-json",
- }
-
- c, _ := newOpsRetryContext(context.Background(), errorLog)
- require.NotNil(t, c)
- require.NotNil(t, c.Request)
- require.Equal(t, "/", c.Request.URL.Path)
- require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c))
-}
diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go
index 11afc6f9..1cea72fa 100644
--- a/backend/internal/service/ops_service.go
+++ b/backend/internal/service/ops_service.go
@@ -16,26 +16,9 @@ import (
var ErrOpsDisabled = infraerrors.NotFound("OPS_DISABLED", "Ops monitoring is disabled")
const (
- opsMaxStoredRequestBodyBytes = 256 * 1024
- opsMaxStoredErrorBodyBytes = 20 * 1024
+ opsMaxStoredErrorBodyBytes = 20 * 1024
)
-// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。
-// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。
-func PrepareOpsRequestBodyForQueue(raw []byte) (requestBodyJSON *string, truncated bool, requestBodyBytes *int) {
- if len(raw) == 0 {
- return nil, false, nil
- }
- sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes)
- if sanitized != "" {
- out := sanitized
- requestBodyJSON = &out
- }
- n := bytesLen
- requestBodyBytes = &n
- return requestBodyJSON, truncated, requestBodyBytes
-}
-
// OpsService provides ingestion and query APIs for the Ops monitoring module.
type OpsService struct {
opsRepo OpsRepository
@@ -138,8 +121,8 @@ func (s *OpsService) IsMonitoringEnabled(ctx context.Context) bool {
}
}
-func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) error {
- prepared, ok, err := s.prepareErrorLogInput(ctx, entry, rawRequestBody)
+func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogInput) error {
+ prepared, ok, err := s.prepareErrorLogInput(ctx, entry)
if err != nil {
log.Printf("[Ops] RecordError prepare failed: %v", err)
return err
@@ -162,7 +145,7 @@ func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertE
}
prepared := make([]*OpsInsertErrorLogInput, 0, len(entries))
for _, entry := range entries {
- item, ok, err := s.prepareErrorLogInput(ctx, entry, nil)
+ item, ok, err := s.prepareErrorLogInput(ctx, entry)
if err != nil {
log.Printf("[Ops] RecordErrorBatch prepare failed: %v", err)
continue
@@ -198,7 +181,7 @@ func (s *OpsService) RecordErrorBatch(ctx context.Context, entries []*OpsInsertE
return nil
}
-func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput, rawRequestBody []byte) (*OpsInsertErrorLogInput, bool, error) {
+func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertErrorLogInput) (*OpsInsertErrorLogInput, bool, error) {
if entry == nil {
return nil, false, nil
}
@@ -224,11 +207,6 @@ func (s *OpsService) prepareErrorLogInput(ctx context.Context, entry *OpsInsertE
entry.ErrorType = "api_error"
}
- // Sanitize + trim request body (errors only).
- if len(rawRequestBody) > 0 {
- entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody)
- }
-
// Sanitize + truncate error_body to avoid storing sensitive data.
if strings.TrimSpace(entry.ErrorBody) != "" {
sanitized, _ := sanitizeErrorBodyForStorage(entry.ErrorBody, opsMaxStoredErrorBodyBytes)
@@ -315,25 +293,6 @@ func sanitizeOpsUpstreamErrors(entry *OpsInsertErrorLogInput) error {
out.Detail = ""
}
- out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
- if out.UpstreamRequestBody != "" {
- // Reuse the same sanitization/trimming strategy as request body storage.
- // Keep it small so it is safe to persist in ops_error_logs JSON.
- sanitizedBody, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
- if sanitizedBody != "" {
- out.UpstreamRequestBody = sanitizedBody
- if truncated {
- out.Kind = strings.TrimSpace(out.Kind)
- if out.Kind == "" {
- out.Kind = "upstream"
- }
- out.Kind = out.Kind + ":request_body_truncated"
- }
- } else {
- out.UpstreamRequestBody = ""
- }
- }
-
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
@@ -381,27 +340,7 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
return detail, nil
}
-func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
- if err := s.RequireMonitoringEnabled(ctx); err != nil {
- return nil, err
- }
- if s.opsRepo == nil {
- return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
- }
- if errorID <= 0 {
- return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
- }
- items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return []*OpsRetryAttempt{}, nil
- }
- return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
- }
- return items, nil
-}
-
-func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
+func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64) error {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return err
}
@@ -418,10 +357,10 @@ func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, r
}
return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
}
- return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
+ return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, nil)
}
-func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
+func sanitizeAndTrimJSONPayload(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
bytesLen = len(raw)
if len(raw) == 0 {
return "", false, 0
@@ -429,7 +368,7 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
- // If it's not valid JSON, don't store (retry would not be reliable anyway).
+ // If it is not valid JSON, fall back to the caller's non-JSON handling.
return "", false, bytesLen
}
@@ -465,7 +404,7 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
// This avoids downstream code that expects certain top-level keys from crashing.
if root, ok := decoded.(map[string]any); ok {
placeholder := shallowCopyMap(root)
- placeholder["request_body_truncated"] = true
+ placeholder["payload_truncated"] = true
// Replace potentially huge arrays/strings, but keep the keys present.
for _, k := range []string{"messages", "contents", "input", "prompt"} {
@@ -488,7 +427,7 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
}
// Final fallback: minimal valid JSON.
- encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
+ encoded4, err4 := json.Marshal(map[string]any{"payload_truncated": true})
if err4 != nil {
return "", true, bytesLen
}
@@ -732,7 +671,7 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
}
// Prefer JSON-safe sanitization when possible.
- if out, trunc, _ := sanitizeAndTrimRequestBody([]byte(raw), maxBytes); out != "" {
+ if out, trunc, _ := sanitizeAndTrimJSONPayload([]byte(raw), maxBytes); out != "" {
return out, trunc
}
diff --git a/backend/internal/service/ops_service_batch_test.go b/backend/internal/service/ops_service_batch_test.go
index f3a14d7f..a9419ad7 100644
--- a/backend/internal/service/ops_service_batch_test.go
+++ b/backend/internal/service/ops_service_batch_test.go
@@ -31,11 +31,10 @@ func TestOpsServiceRecordErrorBatch_SanitizesAndBatches(t *testing.T) {
UpstreamErrorDetail: strPtr(detail),
UpstreamErrors: []*OpsUpstreamErrorEvent{
{
- AccountID: -2,
- UpstreamStatusCode: 429,
- Message: " token leaked ",
- Detail: `{"refresh_token":"secret"}`,
- UpstreamRequestBody: `{"api_key":"secret","messages":[{"role":"user","content":"hello"}]}`,
+ AccountID: -2,
+ UpstreamStatusCode: 429,
+ Message: " token leaked ",
+ Detail: `{"refresh_token":"secret"}`,
},
},
},
diff --git a/backend/internal/service/ops_service_prepare_queue_test.go b/backend/internal/service/ops_service_prepare_queue_test.go
deleted file mode 100644
index d6f32c2d..00000000
--- a/backend/internal/service/ops_service_prepare_queue_test.go
+++ /dev/null
@@ -1,60 +0,0 @@
-package service
-
-import (
- "encoding/json"
- "strings"
- "testing"
-
- "github.com/stretchr/testify/require"
-)
-
-func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) {
- requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil)
- require.Nil(t, requestBodyJSON)
- require.False(t, truncated)
- require.Nil(t, requestBodyBytes)
-}
-
-func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) {
- raw := []byte("{invalid-json")
- requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
- require.Nil(t, requestBodyJSON)
- require.False(t, truncated)
- require.NotNil(t, requestBodyBytes)
- require.Equal(t, len(raw), *requestBodyBytes)
-}
-
-func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) {
- raw := []byte(`{
- "model":"claude-3-5-sonnet-20241022",
- "api_key":"sk-test-123",
- "headers":{"authorization":"Bearer secret-token"},
- "messages":[{"role":"user","content":"hello"}]
- }`)
-
- requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
- require.NotNil(t, requestBodyJSON)
- require.NotNil(t, requestBodyBytes)
- require.False(t, truncated)
- require.Equal(t, len(raw), *requestBodyBytes)
-
- var body map[string]any
- require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body))
- require.Equal(t, "[REDACTED]", body["api_key"])
- headers, ok := body["headers"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, "[REDACTED]", headers["authorization"])
-}
-
-func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) {
- largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2)
- raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`)
-
- requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw)
- require.NotNil(t, requestBodyJSON)
- require.NotNil(t, requestBodyBytes)
- require.True(t, truncated)
- require.Equal(t, len(raw), *requestBodyBytes)
- require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes)
- require.Contains(t, *requestBodyJSON, "request_body_truncated")
-}
diff --git a/backend/internal/service/ops_service_redaction_test.go b/backend/internal/service/ops_service_redaction_test.go
index e0aeafa5..72b85ff0 100644
--- a/backend/internal/service/ops_service_redaction_test.go
+++ b/backend/internal/service/ops_service_redaction_test.go
@@ -45,11 +45,11 @@ func TestIsSensitiveKey_TokenBudgetKeysNotRedacted(t *testing.T) {
}
}
-func TestSanitizeAndTrimRequestBody_PreservesTokenBudgetFields(t *testing.T) {
+func TestSanitizeAndTrimJSONPayload_PreservesTokenBudgetFields(t *testing.T) {
t.Parallel()
raw := []byte(`{"model":"claude-3","max_tokens":123,"thinking":{"type":"enabled","budget_tokens":456},"access_token":"abc","messages":[{"role":"user","content":"hi"}]}`)
- out, _, _ := sanitizeAndTrimRequestBody(raw, 10*1024)
+ out, _, _ := sanitizeAndTrimJSONPayload(raw, 10*1024)
if out == "" {
t.Fatalf("expected non-empty sanitized output")
}
diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go
index 05d444e1..b4ff0e74 100644
--- a/backend/internal/service/ops_upstream_context.go
+++ b/backend/internal/service/ops_upstream_context.go
@@ -16,11 +16,6 @@ const (
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
OpsUpstreamErrorsKey = "ops_upstream_errors"
- // Best-effort capture of the current upstream request body so ops can
- // retry the specific upstream attempt (not just the client request).
- // This value is sanitized+trimmed before being persisted.
- OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
-
// Optional stage latencies (milliseconds) for troubleshooting and alerting.
OpsAuthLatencyMsKey = "ops_auth_latency_ms"
OpsRoutingLatencyMsKey = "ops_routing_latency_ms"
@@ -36,15 +31,13 @@ const (
// OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。
// ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。
OpsSkipPassthroughKey = "ops_skip_passthrough"
-)
-func setOpsUpstreamRequestBody(c *gin.Context, body []byte) {
- if c == nil || len(body) == 0 {
- return
- }
- // 热路径避免 string(body) 额外分配,按需在落库前再转换。
- c.Set(OpsUpstreamRequestBodyKey, body)
-}
+ // Client-side configuration denials should remain visible in ops_error_logs,
+ // but should be excluded from SLA/error-rate calculations.
+ OpsClientBusinessLimitedKey = "ops_client_business_limited"
+ OpsClientBusinessLimitedReasonKey = "ops_client_business_limited_reason"
+ OpsClientBusinessLimitedReasonIPRestriction = "api_key_ip_restriction"
+)
func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
if c == nil || strings.TrimSpace(key) == "" || value < 0 {
@@ -53,6 +46,28 @@ func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
c.Set(key, value)
}
+func MarkOpsClientBusinessLimited(c *gin.Context, reason string) {
+ if c == nil {
+ return
+ }
+ c.Set(OpsClientBusinessLimitedKey, true)
+ if reason = strings.TrimSpace(reason); reason != "" {
+ c.Set(OpsClientBusinessLimitedReasonKey, reason)
+ }
+}
+
+func HasOpsClientBusinessLimited(c *gin.Context) bool {
+ if c == nil {
+ return false
+ }
+ v, ok := c.Get(OpsClientBusinessLimitedKey)
+ if !ok {
+ return false
+ }
+ marked, _ := v.(bool)
+ return marked
+}
+
// SetOpsUpstreamError is the exported wrapper for setOpsUpstreamError, used by
// handler-layer code (e.g. failover-exhausted paths) that needs to record the
// original upstream status code before mapping it to a client-facing code.
@@ -97,10 +112,6 @@ type OpsUpstreamErrorEvent struct {
// Helps debug 404/routing errors by showing which endpoint was targeted.
UpstreamURL string `json:"upstream_url,omitempty"`
- // Best-effort upstream request capture (sanitized+trimmed).
- // Required for retrying a specific upstream attempt.
- UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
-
// Best-effort upstream response capture (sanitized+trimmed).
UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
@@ -120,7 +131,6 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
}
ev.Platform = strings.TrimSpace(ev.Platform)
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
- ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind)
ev.UpstreamURL = strings.TrimSpace(ev.UpstreamURL)
@@ -130,19 +140,6 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
}
- // If the caller didn't explicitly pass upstream request body but the gateway
- // stored it on the context, attach it so ops can retry this specific attempt.
- if ev.UpstreamRequestBody == "" {
- if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
- switch raw := v.(type) {
- case string:
- ev.UpstreamRequestBody = strings.TrimSpace(raw)
- case []byte:
- ev.UpstreamRequestBody = strings.TrimSpace(string(raw))
- }
- }
- }
-
var existing []*OpsUpstreamErrorEvent
if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go
index fa6d1085..711223f4 100644
--- a/backend/internal/service/ops_upstream_context_test.go
+++ b/backend/internal/service/ops_upstream_context_test.go
@@ -1,10 +1,8 @@
package service
import (
- "net/http/httptest"
"testing"
- "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -28,41 +26,3 @@ func TestSafeUpstreamURL(t *testing.T) {
})
}
}
-
-func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
- gin.SetMode(gin.TestMode)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
-
- setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`))
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Kind: "http_error",
- Message: "upstream failed",
- })
-
- v, ok := c.Get(OpsUpstreamErrorsKey)
- require.True(t, ok)
- events, ok := v.([]*OpsUpstreamErrorEvent)
- require.True(t, ok)
- require.Len(t, events, 1)
- require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody)
-}
-
-func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) {
- gin.SetMode(gin.TestMode)
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
-
- c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`)
- appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
- Kind: "request_error",
- Message: "dial timeout",
- })
-
- v, ok := c.Get(OpsUpstreamErrorsKey)
- require.True(t, ok)
- events, ok := v.([]*OpsUpstreamErrorEvent)
- require.True(t, ok)
- require.Len(t, events, 1)
- require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody)
-}
diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go
index f57ac614..022b1b01 100644
--- a/backend/internal/service/payment_config_service.go
+++ b/backend/internal/service/payment_config_service.go
@@ -34,6 +34,7 @@ const (
SettingCancelWindowSize = "CANCEL_RATE_LIMIT_WINDOW"
SettingCancelWindowUnit = "CANCEL_RATE_LIMIT_UNIT"
SettingCancelWindowMode = "CANCEL_RATE_LIMIT_WINDOW_MODE"
+ SettingAlipayForceQRCode = "ALIPAY_FORCE_QRCODE"
)
// Default values for payment configuration settings.
@@ -67,6 +68,9 @@ type PaymentConfig struct {
CancelRateLimitWindow int `json:"cancel_rate_limit_window"`
CancelRateLimitUnit string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode string `json:"cancel_rate_limit_window_mode"`
+
+ // Force Alipay mobile users to use QR code instead of mobile redirect
+ AlipayForceQRCode bool `json:"alipay_force_qrcode"`
}
// UpdatePaymentConfigRequest contains fields to update payment configuration.
@@ -94,6 +98,9 @@ type UpdatePaymentConfigRequest struct {
CancelRateLimitUnit *string `json:"cancel_rate_limit_unit"`
CancelRateLimitMode *string `json:"cancel_rate_limit_window_mode"`
+ // Force Alipay mobile users to use QR code instead of mobile redirect
+ AlipayForceQRCode *bool `json:"alipay_force_qrcode"`
+
VisibleMethodAlipaySource *string `json:"payment_visible_method_alipay_source"`
VisibleMethodWxpaySource *string `json:"payment_visible_method_wxpay_source"`
VisibleMethodAlipayEnabled *bool `json:"payment_visible_method_alipay_enabled"`
@@ -202,6 +209,7 @@ func (s *PaymentConfigService) GetPaymentConfig(ctx context.Context) (*PaymentCo
SettingHelpImageURL, SettingHelpText,
SettingCancelRateLimitOn, SettingCancelRateLimitMax,
SettingCancelWindowSize, SettingCancelWindowUnit, SettingCancelWindowMode,
+ SettingAlipayForceQRCode,
SettingPaymentVisibleMethodAlipayEnabled, SettingPaymentVisibleMethodAlipaySource,
SettingPaymentVisibleMethodWxpayEnabled, SettingPaymentVisibleMethodWxpaySource,
}
@@ -237,6 +245,8 @@ func (s *PaymentConfigService) parsePaymentConfig(vals map[string]string) *Payme
CancelRateLimitWindow: pcParseInt(vals[SettingCancelWindowSize], 1),
CancelRateLimitUnit: vals[SettingCancelWindowUnit],
CancelRateLimitMode: vals[SettingCancelWindowMode],
+
+ AlipayForceQRCode: vals[SettingAlipayForceQRCode] == "true",
}
if cfg.LoadBalanceStrategy == "" {
cfg.LoadBalanceStrategy = payment.DefaultLoadBalanceStrategy
@@ -314,6 +324,7 @@ func (s *PaymentConfigService) UpdatePaymentConfig(ctx context.Context, req Upda
SettingCancelWindowSize: formatPositiveInt(req.CancelRateLimitWindow),
SettingCancelWindowUnit: derefStr(req.CancelRateLimitUnit),
SettingCancelWindowMode: derefStr(req.CancelRateLimitMode),
+ SettingAlipayForceQRCode: formatBoolOrEmpty(req.AlipayForceQRCode),
SettingPaymentVisibleMethodAlipaySource: derefStr(req.VisibleMethodAlipaySource),
SettingPaymentVisibleMethodWxpaySource: derefStr(req.VisibleMethodWxpaySource),
SettingPaymentVisibleMethodAlipayEnabled: formatBoolOrEmpty(req.VisibleMethodAlipayEnabled),
diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go
index 056967f0..e6cc4b3c 100644
--- a/backend/internal/service/payment_order.go
+++ b/backend/internal/service/payment_order.go
@@ -499,24 +499,41 @@ func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string {
func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig, sel *payment.InstanceSelection) string {
if plan != nil {
- if plan.ProductName != "" {
- return plan.ProductName
+ productName := plan.ProductName
+ if productName == "" {
+ productName = "Sub2API Subscription " + plan.Name
}
- return "Sub2API Subscription " + plan.Name
+ return applyPaymentProductNameAffix(productName, cfg)
}
currency := payment.DefaultPaymentCurrency
if sel != nil {
currency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config)
}
amountStr := payment.FormatAmountForCurrency(limitAmount, currency)
- pf := strings.TrimSpace(cfg.ProductNamePrefix)
- sf := strings.TrimSpace(cfg.ProductNameSuffix)
- if pf != "" || sf != "" {
- return strings.TrimSpace(pf + " " + amountStr + " " + sf)
+ if hasPaymentProductNameAffix(cfg) {
+ return applyPaymentProductNameAffix(amountStr, cfg)
}
return "Sub2API " + amountStr + " " + currency
}
+func hasPaymentProductNameAffix(cfg *PaymentConfig) bool {
+ if cfg == nil {
+ return false
+ }
+ pf := strings.TrimSpace(cfg.ProductNamePrefix)
+ sf := strings.TrimSpace(cfg.ProductNameSuffix)
+ return pf != "" || sf != ""
+}
+
+func applyPaymentProductNameAffix(productName string, cfg *PaymentConfig) string {
+ if !hasPaymentProductNameAffix(cfg) {
+ return productName
+ }
+ pf := strings.TrimSpace(cfg.ProductNamePrefix)
+ sf := strings.TrimSpace(cfg.ProductNameSuffix)
+ return strings.TrimSpace(pf + " " + productName + " " + sf)
+}
+
func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) {
return s.maybeBuildWeChatOAuthRequiredResponseForSelection(ctx, req, amount, payAmount, feeRate, nil)
}
diff --git a/backend/internal/service/payment_order_expiry_service.go b/backend/internal/service/payment_order_expiry_service.go
index b0cda3e5..32e51d7f 100644
--- a/backend/internal/service/payment_order_expiry_service.go
+++ b/backend/internal/service/payment_order_expiry_service.go
@@ -59,10 +59,18 @@ func (s *PaymentOrderExpiryService) Stop() {
}
func (s *PaymentOrderExpiryService) runOnce() {
- ctx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
- defer cancel()
+ reconcileCtx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
+ recovered, err := s.paymentSvc.ReconcilePendingWxpayOrders(reconcileCtx)
+ cancel()
+ if err != nil {
+ slog.Warn("[PaymentOrderExpiry] failed to reconcile pending wxpay orders", "error", err)
+ } else if recovered > 0 {
+ slog.Info("[PaymentOrderExpiry] reconciled paid wxpay orders", "count", recovered)
+ }
- expired, err := s.paymentSvc.ExpireTimedOutOrders(ctx)
+ expireCtx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout)
+ defer cancel()
+ expired, err := s.paymentSvc.ExpireTimedOutOrders(expireCtx)
if err != nil {
slog.Error("[PaymentOrderExpiry] failed to expire orders", "error", err)
return
diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go
index b627ced4..ffe120d0 100644
--- a/backend/internal/service/payment_order_lifecycle.go
+++ b/backend/internal/service/payment_order_lifecycle.go
@@ -26,8 +26,14 @@ const (
rateLimitModeFixed = "fixed"
checkPaidResultAlreadyPaid = "already_paid"
checkPaidResultCancelled = "cancelled"
+
+ pendingWxpayReconcileLimit = 20
)
+type checkPaidOptions struct {
+ cancelIfUnpaid bool
+}
+
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
return nil
@@ -136,6 +142,14 @@ func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder,
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
+ return s.checkPaidWithOptions(ctx, o, checkPaidOptions{cancelIfUnpaid: true})
+}
+
+func (s *PaymentService) reconcilePaid(ctx context.Context, o *dbent.PaymentOrder) string {
+ return s.checkPaidWithOptions(ctx, o, checkPaidOptions{})
+}
+
+func (s *PaymentService) checkPaidWithOptions(ctx context.Context, o *dbent.PaymentOrder, opts checkPaidOptions) string {
prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
@@ -182,6 +196,9 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s
}
return checkPaidResultAlreadyPaid
}
+ if !opts.cancelIfUnpaid {
+ return ""
+ }
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, queryRef)
}
@@ -268,7 +285,7 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
}
// Only verify orders that are still pending or recently expired
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
- result := s.checkPaid(ctx, o)
+ result := s.reconcilePaid(ctx, o)
if result == checkPaidResultAlreadyPaid {
// Reload order to get updated status
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
@@ -280,6 +297,37 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo
return o, nil
}
+// ReconcilePendingWxpayOrders actively checks recent pending WeChat orders so
+// missed provider notifications do not wait until order expiry to fulfill.
+func (s *PaymentService) ReconcilePendingWxpayOrders(ctx context.Context) (int, error) {
+ now := time.Now()
+ orders, err := s.entClient.PaymentOrder.Query().
+ Where(
+ paymentorder.StatusEQ(OrderStatusPending),
+ paymentorder.ExpiresAtGT(now),
+ paymentorder.Or(
+ paymentorder.PaymentTypeEQ(payment.TypeWxpay),
+ paymentorder.PaymentTypeHasPrefix(payment.TypeWxpay+"_"),
+ paymentorder.ProviderKeyEQ(payment.TypeWxpay),
+ paymentorder.ProviderKeyHasPrefix(payment.TypeWxpay+"_"),
+ ),
+ ).
+ Order(dbent.Asc(paymentorder.FieldCreatedAt)).
+ Limit(pendingWxpayReconcileLimit).
+ All(ctx)
+ if err != nil {
+ return 0, fmt.Errorf("query pending wxpay orders: %w", err)
+ }
+
+ recovered := 0
+ for _, order := range orders {
+ if s.reconcilePaid(ctx, order) == checkPaidResultAlreadyPaid {
+ recovered++
+ }
+ }
+ return recovered, nil
+}
+
// VerifyOrderPublic returns the currently persisted public order state without
// triggering any upstream reconciliation. Signed resume-token recovery is the
// only public recovery path allowed to query upstream state.
diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go
index d8595715..1964cdf6 100644
--- a/backend/internal/service/payment_order_lifecycle_test.go
+++ b/backend/internal/service/payment_order_lifecycle_test.go
@@ -20,10 +20,13 @@ import (
)
type paymentOrderLifecycleQueryProvider struct {
- lastQueryTradeNo string
- queryCalls int
- responses []*payment.QueryOrderResponse
- resp *payment.QueryOrderResponse
+ key string
+ lastQueryTradeNo string
+ lastCancelTradeNo string
+ queryCalls int
+ cancelCalls int
+ responses []*payment.QueryOrderResponse
+ resp *payment.QueryOrderResponse
}
type paymentOrderLifecycleRedeemRepo struct {
@@ -38,10 +41,15 @@ func (p *paymentOrderLifecycleQueryProvider) Name() string {
return "payment-order-lifecycle-query-provider"
}
-func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay }
+func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string {
+ if p.key != "" {
+ return p.key
+ }
+ return payment.TypeAlipay
+}
func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType {
- return []payment.PaymentType{payment.TypeAlipay}
+ return []payment.PaymentType{p.ProviderKey()}
}
func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
@@ -69,6 +77,12 @@ func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.Ref
panic("unexpected call")
}
+func (p *paymentOrderLifecycleQueryProvider) CancelPayment(_ context.Context, tradeNo string) error {
+ p.lastCancelTradeNo = tradeNo
+ p.cancelCalls++
+ return nil
+}
+
func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error {
panic("unexpected call")
}
@@ -435,6 +449,222 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) {
require.Empty(t, redeemRepo.useCalls)
}
+func TestVerifyOrderByOutTradeNoDoesNotCancelUnpaidUpstreamOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("checkpaid-pending@example.com").
+ SetPasswordHash("hash").
+ SetUsername("checkpaid-pending-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CHECKPAID-PENDING").
+ SetOutTradeNo("sub2_checkpaid_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: order.OutTradeNo,
+ Status: payment.ProviderStatusPending,
+ Amount: 0,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusPending, got.Status)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Zero(t, provider.cancelCalls)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusPending, reloaded.Status)
+}
+
+func TestCancelOrderStillClosesUnpaidUpstreamOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("cancel-pending@example.com").
+ SetPasswordHash("hash").
+ SetUsername("cancel-pending-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(88).
+ SetPayAmount(88).
+ SetFeeRate(0).
+ SetRechargeCode("CANCEL-PENDING").
+ SetOutTradeNo("sub2_cancel_pending").
+ SetPaymentType(payment.TypeAlipay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ resp: &payment.QueryOrderResponse{
+ TradeNo: order.OutTradeNo,
+ Status: payment.ProviderStatusPending,
+ Amount: 0,
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ providersLoaded: true,
+ }
+
+ outcome, err := svc.CancelOrder(ctx, order.ID, user.ID)
+ require.NoError(t, err)
+ require.Equal(t, checkPaidResultCancelled, outcome)
+ require.Equal(t, order.OutTradeNo, provider.lastCancelTradeNo)
+ require.Equal(t, 1, provider.cancelCalls)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCancelled, reloaded.Status)
+}
+
+func TestReconcilePendingWxpayOrdersBackfillsPaidOrder(t *testing.T) {
+ ctx := context.Background()
+ client := newPaymentOrderLifecycleTestClient(t)
+
+ user, err := client.User.Create().
+ SetEmail("wxpay-reconcile@example.com").
+ SetPasswordHash("hash").
+ SetUsername("wxpay-reconcile-user").
+ Save(ctx)
+ require.NoError(t, err)
+
+ order, err := client.PaymentOrder.Create().
+ SetUserID(user.ID).
+ SetUserEmail(user.Email).
+ SetUserName(user.Username).
+ SetAmount(50).
+ SetPayAmount(50).
+ SetFeeRate(0).
+ SetRechargeCode("WXPAY-RECONCILE").
+ SetOutTradeNo("sub2_wxpay_reconcile").
+ SetPaymentType(payment.TypeWxpay).
+ SetPaymentTradeNo("").
+ SetOrderType(payment.OrderTypeBalance).
+ SetStatus(OrderStatusPending).
+ SetExpiresAt(time.Now().Add(time.Hour)).
+ SetClientIP("127.0.0.1").
+ SetSrcHost("api.example.com").
+ Save(ctx)
+ require.NoError(t, err)
+
+ userRepo := &mockUserRepo{
+ getByIDUser: &User{
+ ID: user.ID,
+ Email: user.Email,
+ Username: user.Username,
+ Balance: 0,
+ },
+ }
+ userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error {
+ require.Equal(t, user.ID, id)
+ if userRepo.getByIDUser != nil {
+ userRepo.getByIDUser.Balance += amount
+ }
+ return nil
+ }
+ redeemRepo := &paymentOrderLifecycleRedeemRepo{
+ codesByCode: map[string]*RedeemCode{
+ order.RechargeCode: {
+ ID: 1,
+ Code: order.RechargeCode,
+ Type: RedeemTypeBalance,
+ Value: order.Amount,
+ Status: StatusUnused,
+ },
+ },
+ }
+ redeemService := NewRedeemService(
+ redeemRepo,
+ userRepo,
+ nil,
+ nil,
+ nil,
+ client,
+ nil,
+ nil,
+ )
+ registry := payment.NewRegistry()
+ provider := &paymentOrderLifecycleQueryProvider{
+ key: payment.TypeWxpay,
+ resp: &payment.QueryOrderResponse{
+ TradeNo: "wxpay-upstream-trade-123",
+ Status: payment.ProviderStatusPaid,
+ Amount: 50,
+ Metadata: map[string]string{
+ "trade_state": "SUCCESS",
+ },
+ },
+ }
+ registry.Register(provider)
+
+ svc := &PaymentService{
+ entClient: client,
+ registry: registry,
+ redeemService: redeemService,
+ userRepo: userRepo,
+ providersLoaded: true,
+ }
+
+ recovered, err := svc.ReconcilePendingWxpayOrders(ctx)
+ require.NoError(t, err)
+ require.Equal(t, 1, recovered)
+ require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo)
+ require.Zero(t, provider.cancelCalls)
+
+ reloaded, err := client.PaymentOrder.Get(ctx, order.ID)
+ require.NoError(t, err)
+ require.Equal(t, OrderStatusCompleted, reloaded.Status)
+ require.Equal(t, "wxpay-upstream-trade-123", reloaded.PaymentTradeNo)
+ require.Equal(t, 50.0, userRepo.getByIDUser.Balance)
+ require.Len(t, redeemRepo.useCalls, 1)
+}
+
func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) {
ctx := context.Background()
client := newPaymentOrderLifecycleTestClient(t)
diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go
index f78d6b37..bfe27548 100644
--- a/backend/internal/service/payment_order_result_test.go
+++ b/backend/internal/service/payment_order_result_test.go
@@ -138,6 +138,41 @@ func TestCalculateCreateOrderPayAmountRejectsFractionalZeroDecimal(t *testing.T)
}
}
+func TestBuildPaymentSubjectAppliesAffixToSubscriptionPlanProductName(t *testing.T) {
+ t.Parallel()
+
+ svc := &PaymentService{}
+ cfg := &PaymentConfig{
+ ProductNamePrefix: "PRE",
+ ProductNameSuffix: "SUF",
+ }
+ plan := &dbent.SubscriptionPlan{
+ Name: "Pro Monthly",
+ ProductName: "Claude Pro",
+ }
+
+ got := svc.buildPaymentSubject(plan, 0, cfg, nil)
+ if got != "PRE Claude Pro SUF" {
+ t.Fatalf("buildPaymentSubject() = %q, want %q", got, "PRE Claude Pro SUF")
+ }
+}
+
+func TestBuildPaymentSubjectAppliesAffixToSubscriptionPlanDefaultName(t *testing.T) {
+ t.Parallel()
+
+ svc := &PaymentService{}
+ cfg := &PaymentConfig{
+ ProductNamePrefix: "PRE",
+ ProductNameSuffix: "SUF",
+ }
+ plan := &dbent.SubscriptionPlan{Name: "Team Monthly"}
+
+ got := svc.buildPaymentSubject(plan, 0, cfg, nil)
+ if got != "PRE Sub2API Subscription Team Monthly SUF" {
+ t.Fatalf("buildPaymentSubject() = %q, want %q", got, "PRE Sub2API Subscription Team Monthly SUF")
+ }
+}
+
func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) {
t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef")
diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go
index 1ff061e8..fb41ced4 100644
--- a/backend/internal/service/payment_resume_lookup.go
+++ b/backend/internal/service/payment_resume_lookup.go
@@ -46,7 +46,7 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token
return nil, invalidResumeTokenMatchError()
}
if order.Status == OrderStatusPending || order.Status == OrderStatusExpired {
- result := s.checkPaid(ctx, order)
+ result := s.reconcilePaid(ctx, order)
if result == checkPaidResultAlreadyPaid {
order, err = s.entClient.PaymentOrder.Get(ctx, order.ID)
if err != nil {
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index 8a033710..bd0c30df 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -9,6 +9,7 @@ import (
"os"
"path/filepath"
"regexp"
+ "sort"
"strings"
"sync"
"time"
@@ -903,6 +904,24 @@ func (s *PricingService) getHashFilePath() string {
return filepath.Join(s.cfg.Pricing.DataDir, "model_pricing.sha256")
}
+// ListModelNamesByProvider returns all model names in the catalog whose
+// LiteLLMProvider matches the given provider string (case-insensitive).
+// The returned slice is sorted alphabetically.
+func (s *PricingService) ListModelNamesByProvider(provider string) []string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ provider = strings.ToLower(strings.TrimSpace(provider))
+ names := make([]string, 0)
+ for name, p := range s.pricingData {
+ if strings.ToLower(p.LiteLLMProvider) == provider {
+ names = append(names, name)
+ }
+ }
+ sort.Strings(names)
+ return names
+}
+
// isNumeric 检查字符串是否为纯数字
func isNumeric(s string) bool {
for _, c := range s {
diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go
index 3c3e2c5b..cc8b120a 100644
--- a/backend/internal/service/pricing_service_test.go
+++ b/backend/internal/service/pricing_service_test.go
@@ -2,6 +2,8 @@ package service
import (
"encoding/json"
+ "os"
+ "path/filepath"
"testing"
"github.com/stretchr/testify/require"
@@ -111,6 +113,22 @@ func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) {
require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
}
+func TestDefaultPricingIncludesCodexAutoReview(t *testing.T) {
+ data, err := os.ReadFile(filepath.Join("..", "..", "resources", "model-pricing", "model_prices_and_context_window.json"))
+ require.NoError(t, err)
+
+ svc := &PricingService{}
+ pricingData, err := svc.parsePricingData(data)
+ require.NoError(t, err)
+ svc.pricingData = pricingData
+
+ got := svc.GetModelPricing("codex-auto-review")
+ require.NotNil(t, got)
+ require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12)
+ require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12)
+ require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12)
+}
+
func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) {
svc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
@@ -216,3 +234,60 @@ func TestParsePricingData_PreservesServiceTierPriorityFields(t *testing.T) {
require.InDelta(t, 0.0000005, pricing.CacheReadInputTokenCostPriority, 1e-12)
require.True(t, pricing.SupportsServiceTier)
}
+
+// ---------------------------------------------------------------------------
+// ListModelNamesByProvider
+// ---------------------------------------------------------------------------
+
+func TestListModelNamesByProvider_ReturnsMatchingModels(t *testing.T) {
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{
+ "claude-opus-4-5-20251101": {LiteLLMProvider: "anthropic", InputCostPerToken: 1.5e-5},
+ "claude-sonnet-4-5": {LiteLLMProvider: "anthropic", InputCostPerToken: 3e-6},
+ "gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
+ "gemini-2.5-pro": {LiteLLMProvider: "google", InputCostPerToken: 1.25e-6},
+ },
+ }
+
+ got := svc.ListModelNamesByProvider("anthropic")
+ require.ElementsMatch(t, []string{"claude-opus-4-5-20251101", "claude-sonnet-4-5"}, got)
+ // Must be sorted
+ require.Equal(t, "claude-opus-4-5-20251101", got[0])
+ require.Equal(t, "claude-sonnet-4-5", got[1])
+}
+
+func TestListModelNamesByProvider_CaseInsensitive(t *testing.T) {
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{
+ "gpt-4o": {LiteLLMProvider: "OpenAI", InputCostPerToken: 5e-6},
+ },
+ }
+
+ got := svc.ListModelNamesByProvider("openai")
+ require.Equal(t, []string{"gpt-4o"}, got)
+
+ got2 := svc.ListModelNamesByProvider("OPENAI")
+ require.Equal(t, []string{"gpt-4o"}, got2)
+}
+
+func TestListModelNamesByProvider_NoMatch(t *testing.T) {
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{
+ "gpt-4o": {LiteLLMProvider: "openai", InputCostPerToken: 5e-6},
+ },
+ }
+
+ got := svc.ListModelNamesByProvider("anthropic")
+ require.NotNil(t, got)
+ require.Empty(t, got)
+}
+
+func TestListModelNamesByProvider_EmptyCatalog(t *testing.T) {
+ svc := &PricingService{
+ pricingData: map[string]*LiteLLMModelPricing{},
+ }
+
+ got := svc.ListModelNamesByProvider("openai")
+ require.NotNil(t, got)
+ require.Empty(t, got)
+}
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 19c45a5a..892d9aca 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -209,6 +209,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
}
}
+ // 缺少 refresh_token 的 OAuth 账号无法在冷却期内自愈(后台刷新服务也会跳过),
+ // 直接走 SetError 永久禁用,避免冷却结束后再被选中产生一发无意义的 502。
+ if strings.TrimSpace(account.GetCredential("refresh_token")) == "" {
+ msg := "Authentication failed (401): refresh_token missing, cannot recover"
+ if upstreamMsg != "" {
+ msg = "OAuth 401 (no refresh_token): " + upstreamMsg
+ }
+ s.handleAuthError(ctx, account, msg)
+ shouldDisable = true
+ break
+ }
// 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
if account.Credentials == nil {
account.Credentials = make(map[string]any)
diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go
index 73b7849f..a964775e 100644
--- a/backend/internal/service/ratelimit_service_401_test.go
+++ b/backend/internal/service/ratelimit_service_401_test.go
@@ -85,6 +85,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
+ "refresh_token": "rt-100",
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
@@ -138,6 +139,9 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "rt-101",
+ },
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
@@ -175,7 +179,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
- "access_token": "token",
+ "access_token": "token",
+ "refresh_token": "rt-103",
},
}
@@ -185,3 +190,52 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}
+
+// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用,
+// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。
+func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) {
+ t.Run("openai_no_refresh_token", func(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ invalidator := &tokenCacheInvalidatorRecorder{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetTokenCacheInvalidator(invalidator)
+ account := &Account{
+ ID: 2881,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "expired-at",
+ // no refresh_token
+ },
+ }
+
+ shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError")
+ require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule")
+ require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible")
+ require.Contains(t, repo.lastErrorMsg, "refresh_token missing")
+ require.Len(t, invalidator.accounts, 1, "cache should still be invalidated")
+ })
+
+ t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ account := &Account{
+ ID: 2882,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "expired-at",
+ "refresh_token": " ",
+ },
+ }
+
+ shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Equal(t, 0, repo.tempCalls)
+ })
+}
diff --git a/backend/internal/service/redeem_code.go b/backend/internal/service/redeem_code.go
index a66b53ba..55abcfb3 100644
--- a/backend/internal/service/redeem_code.go
+++ b/backend/internal/service/redeem_code.go
@@ -16,6 +16,7 @@ type RedeemCode struct {
UsedAt *time.Time
Notes string
CreatedAt time.Time
+ ExpiresAt *time.Time
GroupID *int64
ValidityDays int
@@ -28,8 +29,22 @@ func (r *RedeemCode) IsUsed() bool {
return r.Status == StatusUsed
}
+func (r *RedeemCode) IsExpired() bool {
+ return r.IsExpiredAt(time.Now())
+}
+
+func (r *RedeemCode) IsExpiredAt(now time.Time) bool {
+ if r == nil {
+ return false
+ }
+ if r.Status == StatusExpired {
+ return true
+ }
+ return r.Status == StatusUnused && r.ExpiresAt != nil && !r.ExpiresAt.After(now)
+}
+
func (r *RedeemCode) CanUse() bool {
- return r.Status == StatusUnused
+ return r.Status == StatusUnused && !r.IsExpired()
}
func GenerateRedeemCode() (string, error) {
diff --git a/backend/internal/service/redeem_code_test.go b/backend/internal/service/redeem_code_test.go
new file mode 100644
index 00000000..ba5c7e7c
--- /dev/null
+++ b/backend/internal/service/redeem_code_test.go
@@ -0,0 +1,59 @@
+package service
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestRedeemCodeExpiry(t *testing.T) {
+ now := time.Now().UTC()
+ past := now.Add(-time.Hour)
+ future := now.Add(time.Hour)
+
+ tests := []struct {
+ name string
+ code RedeemCode
+ wantExpired bool
+ wantCanUse bool
+ }{
+ {
+ name: "unused without expiry can be used",
+ code: RedeemCode{Status: StatusUnused},
+ wantExpired: false,
+ wantCanUse: true,
+ },
+ {
+ name: "unused before expiry can be used",
+ code: RedeemCode{Status: StatusUnused, ExpiresAt: &future},
+ wantExpired: false,
+ wantCanUse: true,
+ },
+ {
+ name: "unused after expiry cannot be used",
+ code: RedeemCode{Status: StatusUnused, ExpiresAt: &past},
+ wantExpired: true,
+ wantCanUse: false,
+ },
+ {
+ name: "explicit expired status is expired",
+ code: RedeemCode{Status: StatusExpired},
+ wantExpired: true,
+ wantCanUse: false,
+ },
+ {
+ name: "used code remains used even after expiry time",
+ code: RedeemCode{Status: StatusUsed, ExpiresAt: &past},
+ wantExpired: false,
+ wantCanUse: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ require.Equal(t, tt.wantExpired, tt.code.IsExpiredAt(now))
+ require.Equal(t, tt.wantCanUse, tt.code.CanUse())
+ })
+ }
+}
diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go
index dcf293c5..73aa02b1 100644
--- a/backend/internal/service/redeem_service.go
+++ b/backend/internal/service/redeem_service.go
@@ -18,6 +18,7 @@ import (
var (
ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
+ ErrRedeemCodeExpired = infraerrors.Conflict("REDEEM_CODE_EXPIRED", "redeem code expired")
ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
@@ -207,6 +208,9 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error
if code.Status == "" {
code.Status = StatusUnused
}
+ if code.IsExpired() {
+ return ErrRedeemCodeExpired
+ }
if err := s.redeemRepo.Create(ctx, code); err != nil {
return fmt.Errorf("create redeem code: %w", err)
@@ -289,7 +293,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) (
return nil, fmt.Errorf("get redeem code: %w", err)
}
- // 检查兑换码状态
+ // 检查兑换码状态和码本身的过期时间
+ if redeemCode.IsExpired() {
+ s.incrementRedeemErrorCount(ctx, userID)
+ return nil, ErrRedeemCodeExpired
+ }
if !redeemCode.CanUse() {
s.incrementRedeemErrorCount(ctx, userID)
return nil, ErrRedeemCodeUsed
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 86978eec..a5c16b1f 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -24,6 +24,25 @@ import (
"golang.org/x/sync/singleflight"
)
+// CoerceDingTalkCorpPolicyForWrite 是 coerceDeprecatedDingTalkCorpPolicy 的导出版本,
+// 用于 admin handler 在写入路径上对客户端直传的入参做防御性 coerce(前端 UI 虽已无 whitelist 选项,
+// 但 API 可被直接调用)。
+func CoerceDingTalkCorpPolicyForWrite(policy string) string {
+ return coerceDeprecatedDingTalkCorpPolicy(policy)
+}
+
+// coerceDeprecatedDingTalkCorpPolicy 把已废弃的 corp_restriction_policy 值替换成安全的等价值。
+// 升级前残留在 DB 中的 "whitelist" 会导致 callback 链路在 default case 静默 fail-closed
+// (所有钉钉登录被拒)。这里统一退化为 "none" 让服务保持可用,并 warn 日志提醒 admin 重新保存设置。
+func coerceDeprecatedDingTalkCorpPolicy(policy string) string {
+ if policy == "whitelist" {
+ slog.Warn("dingtalk: corp_restriction_policy=whitelist is deprecated and unsupported, coercing to none",
+ "hint", "re-save DingTalk settings in admin UI to clear this warning")
+ return "none"
+ }
+ return policy
+}
+
var (
ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
@@ -146,6 +165,7 @@ type AuthSourceDefaultSettings struct {
WeChat ProviderDefaultGrantSettings
GitHub ProviderDefaultGrantSettings
Google ProviderDefaultGrantSettings
+ DingTalk ProviderDefaultGrantSettings
ForceEmailOnThirdPartySignup bool
}
@@ -200,6 +220,13 @@ var (
grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
}
+ dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{
+ balance: SettingKeyAuthSourceDefaultDingTalkBalance,
+ concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency,
+ subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions,
+ grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
+ grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
+ }
)
const (
@@ -606,6 +633,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyCustomMenuItems,
SettingKeyCustomEndpoints,
SettingKeyLinuxDoConnectEnabled,
+ SettingKeyDingTalkConnectEnabled,
SettingKeyWeChatConnectEnabled,
SettingKeyWeChatConnectAppID,
SettingKeyWeChatConnectAppSecret,
@@ -654,6 +682,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
} else {
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
}
+ dingTalkEnabled := false
+ if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok {
+ dingTalkEnabled = raw == "true"
+ } else {
+ dingTalkEnabled = s.cfg != nil && s.cfg.DingTalk.Enabled
+ }
oidcEnabled := false
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
oidcEnabled = raw == "true"
@@ -723,6 +757,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
CustomMenuItems: settings[SettingKeyCustomMenuItems],
CustomEndpoints: settings[SettingKeyCustomEndpoints],
LinuxDoOAuthEnabled: linuxDoEnabled,
+ DingTalkOAuthEnabled: dingTalkEnabled,
WeChatOAuthEnabled: weChatEnabled,
WeChatOAuthOpenEnabled: weChatOpenEnabled,
WeChatOAuthMPEnabled: weChatMPEnabled,
@@ -926,6 +961,7 @@ type PublicSettingsInjectionPayload struct {
CustomMenuItems json.RawMessage `json:"custom_menu_items"`
CustomEndpoints json.RawMessage `json:"custom_endpoints"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"`
@@ -990,6 +1026,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
+ DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled,
@@ -1476,6 +1513,26 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
}
+ // DingTalk Connect OAuth 登录
+ updates[SettingKeyDingTalkConnectEnabled] = strconv.FormatBool(settings.DingTalkConnectEnabled)
+ updates[SettingKeyDingTalkConnectClientID] = settings.DingTalkConnectClientID
+ updates[SettingKeyDingTalkConnectRedirectURL] = settings.DingTalkConnectRedirectURL
+ if settings.DingTalkConnectClientSecret != "" {
+ updates[SettingKeyDingTalkConnectClientSecret] = settings.DingTalkConnectClientSecret
+ }
+ updates[SettingKeyDingTalkConnectCorpRestrictionPolicy] = settings.DingTalkConnectCorpRestrictionPolicy
+ updates[SettingKeyDingTalkConnectInternalCorpID] = settings.DingTalkConnectInternalCorpID
+ updates[SettingKeyDingTalkConnectBypassRegistration] = strconv.FormatBool(settings.DingTalkConnectBypassRegistration)
+ updates[SettingKeyDingTalkConnectSyncCorpEmail] = strconv.FormatBool(settings.DingTalkConnectSyncCorpEmail)
+ updates[SettingKeyDingTalkConnectSyncDisplayName] = strconv.FormatBool(settings.DingTalkConnectSyncDisplayName)
+ updates[SettingKeyDingTalkConnectSyncDept] = strconv.FormatBool(settings.DingTalkConnectSyncDept)
+ updates[SettingKeyDingTalkConnectSyncCorpEmailAttrKey] = settings.DingTalkConnectSyncCorpEmailAttrKey
+ updates[SettingKeyDingTalkConnectSyncDisplayNameAttrKey] = settings.DingTalkConnectSyncDisplayNameAttrKey
+ updates[SettingKeyDingTalkConnectSyncDeptAttrKey] = settings.DingTalkConnectSyncDeptAttrKey
+ updates[SettingKeyDingTalkConnectSyncCorpEmailAttrName] = settings.DingTalkConnectSyncCorpEmailAttrName
+ updates[SettingKeyDingTalkConnectSyncDisplayNameAttrName] = settings.DingTalkConnectSyncDisplayNameAttrName
+ updates[SettingKeyDingTalkConnectSyncDeptAttrName] = settings.DingTalkConnectSyncDeptAttrName
+
// Generic OIDC OAuth 登录
updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled)
updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName
@@ -1677,19 +1734,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett
settings.WeChat.Subscriptions,
settings.GitHub.Subscriptions,
settings.Google.Subscriptions,
+ settings.DingTalk.Subscriptions,
} {
if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
return nil, err
}
}
- updates := make(map[string]string, 31)
+ updates := make(map[string]string, 36)
writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
+ writeProviderDefaultGrantUpdates(updates, dingTalkAuthSourceDefaultKeys, settings.DingTalk)
updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
return updates, nil
}
@@ -2225,6 +2284,11 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
SettingKeyAuthSourceDefaultGoogleSubscriptions,
SettingKeyAuthSourceDefaultGoogleGrantOnSignup,
SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind,
+ SettingKeyAuthSourceDefaultDingTalkBalance,
+ SettingKeyAuthSourceDefaultDingTalkConcurrency,
+ SettingKeyAuthSourceDefaultDingTalkSubscriptions,
+ SettingKeyAuthSourceDefaultDingTalkGrantOnSignup,
+ SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind,
SettingKeyForceEmailOnThirdPartySignup,
}
@@ -2240,6 +2304,7 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut
WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys),
GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys),
Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys),
+ DingTalk: parseProviderDefaultGrantSettings(settings, dingTalkAuthSourceDefaultKeys),
ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
}, nil
}
@@ -2316,111 +2381,116 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 初始化默认设置
defaults := map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyEmailVerifyEnabled: "false",
- SettingKeyRegistrationEmailSuffixWhitelist: "[]",
- SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
- SettingKeyLoginAgreementEnabled: "false",
- SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
- SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
- SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
- SettingKeySiteName: "Sub2API",
- SettingKeySiteLogo: "",
- SettingKeyPurchaseSubscriptionEnabled: "false",
- SettingKeyPurchaseSubscriptionURL: "",
- SettingKeyTableDefaultPageSize: "20",
- SettingKeyTablePageSizeOptions: "[10,20,50,100]",
- SettingKeyCustomMenuItems: "[]",
- SettingKeyCustomEndpoints: "[]",
- SettingKeyWeChatConnectEnabled: "false",
- SettingKeyWeChatConnectAppID: "",
- SettingKeyWeChatConnectAppSecret: "",
- SettingKeyWeChatConnectOpenAppID: "",
- SettingKeyWeChatConnectOpenAppSecret: "",
- SettingKeyWeChatConnectMPAppID: "",
- SettingKeyWeChatConnectMPAppSecret: "",
- SettingKeyWeChatConnectMobileAppID: "",
- SettingKeyWeChatConnectMobileAppSecret: "",
- SettingKeyWeChatConnectOpenEnabled: "false",
- SettingKeyWeChatConnectMPEnabled: "false",
- SettingKeyWeChatConnectMobileEnabled: "false",
- SettingKeyWeChatConnectMode: "open",
- SettingKeyWeChatConnectScopes: "snsapi_login",
- SettingKeyWeChatConnectRedirectURL: "",
- SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
- SettingKeyGitHubOAuthEnabled: "false",
- SettingKeyGitHubOAuthClientID: "",
- SettingKeyGitHubOAuthClientSecret: "",
- SettingKeyGitHubOAuthRedirectURL: "",
- SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
- SettingKeyGoogleOAuthEnabled: "false",
- SettingKeyGoogleOAuthClientID: "",
- SettingKeyGoogleOAuthClientSecret: "",
- SettingKeyGoogleOAuthRedirectURL: "",
- SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
- SettingKeyOIDCConnectEnabled: "false",
- SettingKeyOIDCConnectProviderName: "OIDC",
- SettingKeyOIDCConnectClientID: "",
- SettingKeyOIDCConnectClientSecret: "",
- SettingKeyOIDCConnectIssuerURL: "",
- SettingKeyOIDCConnectDiscoveryURL: "",
- SettingKeyOIDCConnectAuthorizeURL: "",
- SettingKeyOIDCConnectTokenURL: "",
- SettingKeyOIDCConnectUserInfoURL: "",
- SettingKeyOIDCConnectJWKSURL: "",
- SettingKeyOIDCConnectScopes: "openid email profile",
- SettingKeyOIDCConnectRedirectURL: "",
- SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
- SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
- SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
- SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
- SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
- SettingKeyOIDCConnectClockSkewSeconds: "120",
- SettingKeyOIDCConnectRequireEmailVerified: "false",
- SettingKeyOIDCConnectUserInfoEmailPath: "",
- SettingKeyOIDCConnectUserInfoIDPath: "",
- SettingKeyOIDCConnectUserInfoUsernamePath: "",
- SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
- SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
- SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
- SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
- SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
- SettingKeyDefaultUserRPMLimit: "0",
- SettingKeyDefaultSubscriptions: "[]",
- SettingKeyAuthSourceDefaultEmailBalance: "0",
- SettingKeyAuthSourceDefaultEmailConcurrency: "5",
- SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
- SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
- SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
- SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
- SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
- SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
- SettingKeyAuthSourceDefaultOIDCBalance: "0",
- SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
- SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
- SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
- SettingKeyAuthSourceDefaultWeChatBalance: "0",
- SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
- SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
- SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
- SettingKeyAuthSourceDefaultGitHubBalance: "0",
- SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
- SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
- SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
- SettingKeyAuthSourceDefaultGoogleBalance: "0",
- SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
- SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
- SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
- SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
- SettingKeyForceEmailOnThirdPartySignup: "false",
- SettingKeySMTPPort: "587",
- SettingKeySMTPUseTLS: "false",
+ SettingKeyRegistrationEnabled: "true",
+ SettingKeyEmailVerifyEnabled: "false",
+ SettingKeyRegistrationEmailSuffixWhitelist: "[]",
+ SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能
+ SettingKeyLoginAgreementEnabled: "false",
+ SettingKeyLoginAgreementMode: defaultLoginAgreementMode,
+ SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate,
+ SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON,
+ SettingKeySiteName: "Sub2API",
+ SettingKeySiteLogo: "",
+ SettingKeyPurchaseSubscriptionEnabled: "false",
+ SettingKeyPurchaseSubscriptionURL: "",
+ SettingKeyTableDefaultPageSize: "20",
+ SettingKeyTablePageSizeOptions: "[10,20,50,100]",
+ SettingKeyCustomMenuItems: "[]",
+ SettingKeyCustomEndpoints: "[]",
+ SettingKeyWeChatConnectEnabled: "false",
+ SettingKeyWeChatConnectAppID: "",
+ SettingKeyWeChatConnectAppSecret: "",
+ SettingKeyWeChatConnectOpenAppID: "",
+ SettingKeyWeChatConnectOpenAppSecret: "",
+ SettingKeyWeChatConnectMPAppID: "",
+ SettingKeyWeChatConnectMPAppSecret: "",
+ SettingKeyWeChatConnectMobileAppID: "",
+ SettingKeyWeChatConnectMobileAppSecret: "",
+ SettingKeyWeChatConnectOpenEnabled: "false",
+ SettingKeyWeChatConnectMPEnabled: "false",
+ SettingKeyWeChatConnectMobileEnabled: "false",
+ SettingKeyWeChatConnectMode: "open",
+ SettingKeyWeChatConnectScopes: "snsapi_login",
+ SettingKeyWeChatConnectRedirectURL: "",
+ SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend,
+ SettingKeyGitHubOAuthEnabled: "false",
+ SettingKeyGitHubOAuthClientID: "",
+ SettingKeyGitHubOAuthClientSecret: "",
+ SettingKeyGitHubOAuthRedirectURL: "",
+ SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend,
+ SettingKeyGoogleOAuthEnabled: "false",
+ SettingKeyGoogleOAuthClientID: "",
+ SettingKeyGoogleOAuthClientSecret: "",
+ SettingKeyGoogleOAuthRedirectURL: "",
+ SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend,
+ SettingKeyOIDCConnectEnabled: "false",
+ SettingKeyOIDCConnectProviderName: "OIDC",
+ SettingKeyOIDCConnectClientID: "",
+ SettingKeyOIDCConnectClientSecret: "",
+ SettingKeyOIDCConnectIssuerURL: "",
+ SettingKeyOIDCConnectDiscoveryURL: "",
+ SettingKeyOIDCConnectAuthorizeURL: "",
+ SettingKeyOIDCConnectTokenURL: "",
+ SettingKeyOIDCConnectUserInfoURL: "",
+ SettingKeyOIDCConnectJWKSURL: "",
+ SettingKeyOIDCConnectScopes: "openid email profile",
+ SettingKeyOIDCConnectRedirectURL: "",
+ SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback",
+ SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post",
+ SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault),
+ SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault),
+ SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256",
+ SettingKeyOIDCConnectClockSkewSeconds: "120",
+ SettingKeyOIDCConnectRequireEmailVerified: "false",
+ SettingKeyOIDCConnectUserInfoEmailPath: "",
+ SettingKeyOIDCConnectUserInfoIDPath: "",
+ SettingKeyOIDCConnectUserInfoUsernamePath: "",
+ SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
+ SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
+ SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64),
+ SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault),
+ SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault),
+ SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64),
+ SettingKeyDefaultUserRPMLimit: "0",
+ SettingKeyDefaultSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailBalance: "0",
+ SettingKeyAuthSourceDefaultEmailConcurrency: "5",
+ SettingKeyAuthSourceDefaultEmailSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultLinuxDoBalance: "0",
+ SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5",
+ SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultOIDCBalance: "0",
+ SettingKeyAuthSourceDefaultOIDCConcurrency: "5",
+ SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultWeChatBalance: "0",
+ SettingKeyAuthSourceDefaultWeChatConcurrency: "5",
+ SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultGitHubBalance: "0",
+ SettingKeyAuthSourceDefaultGitHubConcurrency: "5",
+ SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultGoogleBalance: "0",
+ SettingKeyAuthSourceDefaultGoogleConcurrency: "5",
+ SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false",
+ SettingKeyAuthSourceDefaultDingTalkBalance: "0",
+ SettingKeyAuthSourceDefaultDingTalkConcurrency: "5",
+ SettingKeyAuthSourceDefaultDingTalkSubscriptions: "[]",
+ SettingKeyAuthSourceDefaultDingTalkGrantOnSignup: "false",
+ SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind: "false",
+ SettingKeyForceEmailOnThirdPartySignup: "false",
+ SettingKeySMTPPort: "587",
+ SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
@@ -2599,6 +2669,136 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
+ // DingTalk Connect 设置:
+ // - 兼容 config.yaml/env
+ // - 支持后台系统设置覆盖并持久化(存储于 DB)
+ dingTalkBase := config.DingTalkConnectConfig{}
+ if s.cfg != nil {
+ dingTalkBase = s.cfg.DingTalk
+ }
+
+ if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok {
+ result.DingTalkConnectEnabled = raw == "true"
+ } else {
+ result.DingTalkConnectEnabled = dingTalkBase.Enabled
+ }
+
+ if v, ok := settings[SettingKeyDingTalkConnectClientID]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectClientID = strings.TrimSpace(v)
+ } else {
+ result.DingTalkConnectClientID = dingTalkBase.ClientID
+ }
+
+ if v, ok := settings[SettingKeyDingTalkConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectRedirectURL = strings.TrimSpace(v)
+ } else {
+ result.DingTalkConnectRedirectURL = dingTalkBase.RedirectURL
+ }
+
+ result.DingTalkConnectClientSecret = strings.TrimSpace(settings[SettingKeyDingTalkConnectClientSecret])
+ if result.DingTalkConnectClientSecret == "" {
+ result.DingTalkConnectClientSecret = strings.TrimSpace(dingTalkBase.ClientSecret)
+ }
+ result.DingTalkConnectClientSecretConfigured = result.DingTalkConnectClientSecret != ""
+
+ if v, ok := settings[SettingKeyDingTalkConnectCorpRestrictionPolicy]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(v)
+ } else {
+ result.DingTalkConnectCorpRestrictionPolicy = dingTalkBase.CorpRestrictionPolicy
+ }
+ result.DingTalkConnectCorpRestrictionPolicy = coerceDeprecatedDingTalkCorpPolicy(result.DingTalkConnectCorpRestrictionPolicy)
+
+ if v, ok := settings[SettingKeyDingTalkConnectInternalCorpID]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectInternalCorpID = strings.TrimSpace(v)
+ } else {
+ result.DingTalkConnectInternalCorpID = dingTalkBase.InternalCorpID
+ }
+
+ if v, ok := settings[SettingKeyDingTalkConnectBypassRegistration]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectBypassRegistration = strings.EqualFold(strings.TrimSpace(v), "true")
+ } else {
+ result.DingTalkConnectBypassRegistration = dingTalkBase.BypassRegistration
+ }
+ // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制 false,
+ // 以保证加载出的 effective config 永远是一致状态。
+ if result.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
+ result.DingTalkConnectBypassRegistration = false
+ }
+
+ if v, ok := settings[SettingKeyDingTalkConnectSyncCorpEmail]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectSyncCorpEmail = strings.EqualFold(strings.TrimSpace(v), "true")
+ } else {
+ result.DingTalkConnectSyncCorpEmail = dingTalkBase.SyncCorpEmail
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectSyncDisplayName]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectSyncDisplayName = strings.EqualFold(strings.TrimSpace(v), "true")
+ } else {
+ result.DingTalkConnectSyncDisplayName = dingTalkBase.SyncDisplayName
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectSyncDept]; ok && strings.TrimSpace(v) != "" {
+ result.DingTalkConnectSyncDept = strings.EqualFold(strings.TrimSpace(v), "true")
+ } else {
+ result.DingTalkConnectSyncDept = dingTalkBase.SyncDept
+ }
+ // 身份同步三开关仅在 internal_only 模式下有意义;其它策略强制 false。
+ if result.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
+ result.DingTalkConnectSyncCorpEmail = false
+ result.DingTalkConnectSyncDisplayName = false
+ result.DingTalkConnectSyncDept = false
+ }
+
+ // 身份同步目标 attr key(DB 空 → fallback 默认值)
+ result.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
+ if result.DingTalkConnectSyncCorpEmailAttrKey == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncCorpEmailAttrKey); v != "" {
+ result.DingTalkConnectSyncCorpEmailAttrKey = v
+ } else {
+ result.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
+ }
+ }
+ result.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
+ if result.DingTalkConnectSyncDisplayNameAttrKey == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncDisplayNameAttrKey); v != "" {
+ result.DingTalkConnectSyncDisplayNameAttrKey = v
+ } else {
+ result.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
+ }
+ }
+ result.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrKey])
+ if result.DingTalkConnectSyncDeptAttrKey == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncDeptAttrKey); v != "" {
+ result.DingTalkConnectSyncDeptAttrKey = v
+ } else {
+ result.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
+ }
+ }
+
+ // 身份同步目标 attr 显示名称(DB 空 → fallback 默认中文)
+ result.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrName])
+ if result.DingTalkConnectSyncCorpEmailAttrName == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncCorpEmailAttrName); v != "" {
+ result.DingTalkConnectSyncCorpEmailAttrName = v
+ } else {
+ result.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
+ }
+ }
+ result.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrName])
+ if result.DingTalkConnectSyncDisplayNameAttrName == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncDisplayNameAttrName); v != "" {
+ result.DingTalkConnectSyncDisplayNameAttrName = v
+ } else {
+ result.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
+ }
+ }
+ result.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrName])
+ if result.DingTalkConnectSyncDeptAttrName == "" {
+ if v := strings.TrimSpace(dingTalkBase.SyncDeptAttrName); v != "" {
+ result.DingTalkConnectSyncDeptAttrName = v
+ } else {
+ result.DingTalkConnectSyncDeptAttrName = "钉钉部门"
+ }
+ }
+
// Generic OIDC 设置:
// - 兼容 config.yaml/env
// - 支持后台系统设置覆盖并持久化(存储于 DB)
@@ -2992,10 +3192,14 @@ func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettin
GrantOnFirstBind: providerDefaults.GrantOnFirstBind,
}
- if providerDefaults.Balance != defaultAuthSourceBalance {
+ // 注意:不能把 parse 默认值 (defaultAuthSourceBalance / defaultAuthSourceConcurrency)
+ // 当作"未配置"哨兵——admin 完全有权显式设成相同的值,那时仍应覆盖 globalDefaults。
+ // 旧实现的 `!= defaultAuthSourceConcurrency` 会把 admin 设的 5 与 fallback 5 混淆,
+ // 导致渠道发放退回到全局默认(如 1),表现为"管理员设 5、新用户实际拿 1"。
+ if providerDefaults.Balance >= 0 {
result.Balance = providerDefaults.Balance
}
- if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency {
+ if providerDefaults.Concurrency > 0 {
result.Concurrency = providerDefaults.Concurrency
}
if len(providerDefaults.Subscriptions) > 0 {
@@ -3281,6 +3485,157 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
+// GetDingTalkConnectOAuthConfig 返回用于登录的"最终生效" DingTalk Connect 配置。
+//
+// 优先级:
+// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值
+// - 否则回退到 config.yaml/env 的值
+func (s *SettingService) GetDingTalkConnectOAuthConfig(ctx context.Context) (config.DingTalkConnectConfig, error) {
+ if s == nil || s.cfg == nil {
+ return config.DingTalkConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded")
+ }
+
+ effective := s.cfg.DingTalk
+
+ keys := []string{
+ SettingKeyDingTalkConnectEnabled,
+ SettingKeyDingTalkConnectClientID,
+ SettingKeyDingTalkConnectClientSecret,
+ SettingKeyDingTalkConnectRedirectURL,
+ SettingKeyDingTalkConnectCorpRestrictionPolicy,
+ SettingKeyDingTalkConnectInternalCorpID,
+ SettingKeyDingTalkConnectBypassRegistration,
+ SettingKeyDingTalkConnectSyncCorpEmail,
+ SettingKeyDingTalkConnectSyncDisplayName,
+ SettingKeyDingTalkConnectSyncDept,
+ SettingKeyDingTalkConnectSyncCorpEmailAttrKey,
+ SettingKeyDingTalkConnectSyncDisplayNameAttrKey,
+ SettingKeyDingTalkConnectSyncDeptAttrKey,
+ }
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
+ if err != nil {
+ return config.DingTalkConnectConfig{}, fmt.Errorf("get dingtalk connect settings: %w", err)
+ }
+
+ if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok {
+ effective.Enabled = raw == "true"
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectClientID]; ok && strings.TrimSpace(v) != "" {
+ effective.ClientID = strings.TrimSpace(v)
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectClientSecret]; ok && strings.TrimSpace(v) != "" {
+ effective.ClientSecret = strings.TrimSpace(v)
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
+ effective.RedirectURL = strings.TrimSpace(v)
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectCorpRestrictionPolicy]; ok && strings.TrimSpace(v) != "" {
+ effective.CorpRestrictionPolicy = strings.TrimSpace(v)
+ }
+ effective.CorpRestrictionPolicy = coerceDeprecatedDingTalkCorpPolicy(effective.CorpRestrictionPolicy)
+ if v, ok := settings[SettingKeyDingTalkConnectInternalCorpID]; ok && strings.TrimSpace(v) != "" {
+ effective.InternalCorpID = strings.TrimSpace(v)
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectBypassRegistration]; ok && strings.TrimSpace(v) != "" {
+ effective.BypassRegistration = strings.EqualFold(strings.TrimSpace(v), "true")
+ }
+ // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制 false,
+ // 以保证 OAuth callback 看到的 effective config 永远是一致状态。
+ if effective.CorpRestrictionPolicy != "internal_only" {
+ effective.BypassRegistration = false
+ }
+
+ if v, ok := settings[SettingKeyDingTalkConnectSyncCorpEmail]; ok && strings.TrimSpace(v) != "" {
+ effective.SyncCorpEmail = strings.EqualFold(strings.TrimSpace(v), "true")
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectSyncDisplayName]; ok && strings.TrimSpace(v) != "" {
+ effective.SyncDisplayName = strings.EqualFold(strings.TrimSpace(v), "true")
+ }
+ if v, ok := settings[SettingKeyDingTalkConnectSyncDept]; ok && strings.TrimSpace(v) != "" {
+ effective.SyncDept = strings.EqualFold(strings.TrimSpace(v), "true")
+ }
+ // 身份同步三开关仅在 internal_only 模式下有意义;其它策略强制 false。
+ if effective.CorpRestrictionPolicy != "internal_only" {
+ effective.SyncCorpEmail = false
+ effective.SyncDisplayName = false
+ effective.SyncDept = false
+ }
+
+ // 身份同步目标 attr key(DB 空 → fallback 默认值)
+ if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrKey]); v != "" {
+ effective.SyncCorpEmailAttrKey = v
+ }
+ if effective.SyncCorpEmailAttrKey == "" {
+ effective.SyncCorpEmailAttrKey = "dingtalk_email"
+ }
+ if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrKey]); v != "" {
+ effective.SyncDisplayNameAttrKey = v
+ }
+ if effective.SyncDisplayNameAttrKey == "" {
+ effective.SyncDisplayNameAttrKey = "dingtalk_name"
+ }
+ if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrKey]); v != "" {
+ effective.SyncDeptAttrKey = v
+ }
+ if effective.SyncDeptAttrKey == "" {
+ effective.SyncDeptAttrKey = "dingtalk_department"
+ }
+
+ if !effective.Enabled {
+ return config.DingTalkConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "dingtalk oauth login is disabled")
+ }
+
+ // 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。
+ if strings.TrimSpace(effective.ClientID) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth client id not configured")
+ }
+ if strings.TrimSpace(effective.AuthorizeURL) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth authorize url not configured")
+ }
+ if strings.TrimSpace(effective.TokenURL) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth token url not configured")
+ }
+ if strings.TrimSpace(effective.UserInfoURL) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth userinfo url not configured")
+ }
+ if strings.TrimSpace(effective.RedirectURL) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth redirect url not configured")
+ }
+ if strings.TrimSpace(effective.FrontendRedirectURL) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth frontend redirect url not configured")
+ }
+
+ if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth authorize url invalid")
+ }
+ if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth token url invalid")
+ }
+ if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth userinfo url invalid")
+ }
+ if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth redirect url invalid")
+ }
+ if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth frontend redirect url invalid")
+ }
+ if strings.TrimSpace(effective.ClientSecret) == "" {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth client secret not configured")
+ }
+
+ // 镜像 admin handler 行为:internal_only policy 隐式要求 AppType=internal
+ if effective.CorpRestrictionPolicy == "internal_only" {
+ effective.AppType = "internal"
+ }
+
+ if err := config.ValidateDingTalkConfig(effective); err != nil {
+ return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", err.Error())
+ }
+
+ return effective, nil
+}
+
// GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。
//
// WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index bfe85995..ea5fa57c 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -46,6 +46,25 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool
LinuxDoConnectRedirectURL string
+ // DingTalk Connect OAuth 登录
+ DingTalkConnectEnabled bool
+ DingTalkConnectClientID string
+ DingTalkConnectClientSecret string
+ DingTalkConnectClientSecretConfigured bool
+ DingTalkConnectRedirectURL string
+ DingTalkConnectCorpRestrictionPolicy string
+ DingTalkConnectInternalCorpID string
+ DingTalkConnectBypassRegistration bool
+ DingTalkConnectSyncCorpEmail bool
+ DingTalkConnectSyncDisplayName bool
+ DingTalkConnectSyncDept bool
+ DingTalkConnectSyncCorpEmailAttrKey string
+ DingTalkConnectSyncDisplayNameAttrKey string
+ DingTalkConnectSyncDeptAttrKey string
+ DingTalkConnectSyncCorpEmailAttrName string
+ DingTalkConnectSyncDisplayNameAttrName string
+ DingTalkConnectSyncDeptAttrName string
+
// WeChat Connect OAuth 登录
WeChatConnectEnabled bool
WeChatConnectAppID string
@@ -235,6 +254,7 @@ type PublicSettings struct {
CustomEndpoints string // JSON array of custom endpoints
LinuxDoOAuthEnabled bool
+ DingTalkOAuthEnabled bool
WeChatOAuthEnabled bool
WeChatOAuthOpenEnabled bool
WeChatOAuthMPEnabled bool
@@ -491,25 +511,10 @@ type OpenAIFastPolicySettings struct {
}
// DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。
-// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段,
-// 让上游按 normal 优先级处理。
-//
-// 为什么 ModelWhitelist 为空(=对所有模型生效):
-// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使
-// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁
-// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。
-// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定
-// 模型,可在 admin UI 中显式配置 model_whitelist。
+// 默认不配置任何规则,保留 OpenAI 上游 service_tier 语义;管理员如需
+// 限制 priority/flex,可以在 admin UI 中显式配置 filter 或 block 规则。
func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings {
return &OpenAIFastPolicySettings{
- Rules: []OpenAIFastPolicyRule{
- {
- ServiceTier: OpenAIFastTierPriority,
- Action: BetaPolicyActionFilter,
- Scope: BetaPolicyScopeAll,
- ModelWhitelist: []string{},
- FallbackAction: BetaPolicyActionPass,
- },
- },
+ Rules: []OpenAIFastPolicyRule{},
}
}
diff --git a/backend/internal/service/slice_helpers.go b/backend/internal/service/slice_helpers.go
new file mode 100644
index 00000000..4894f5aa
--- /dev/null
+++ b/backend/internal/service/slice_helpers.go
@@ -0,0 +1,10 @@
+package service
+
+func containsInt64(values []int64, target int64) bool {
+ for _, v := range values {
+ if v == target {
+ return true
+ }
+ }
+ return false
+}
diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go
index 40bab206..c8ace613 100644
--- a/backend/internal/service/subscription_assign_idempotency_test.go
+++ b/backend/internal/service/subscription_assign_idempotency_test.go
@@ -199,6 +199,24 @@ func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*Use
return &cp, nil
}
+func (s *subscriptionUserSubRepoStub) Update(_ context.Context, sub *UserSubscription) error {
+ if sub == nil {
+ return ErrSubscriptionNilInput
+ }
+ existing := s.byID[sub.ID]
+ if existing == nil {
+ return ErrSubscriptionNotFound
+ }
+ oldKey := s.key(existing.UserID, existing.GroupID)
+ cp := *sub
+ s.byID[cp.ID] = &cp
+ if oldKey != s.key(cp.UserID, cp.GroupID) {
+ delete(s.byUserGroup, oldKey)
+ }
+ s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp
+ return nil
+}
+
func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) {
start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC)
groupRepo := &subscriptionGroupRepoStub{
diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go
index 53e5c568..650522d5 100644
--- a/backend/internal/service/subscription_calculate_progress_test.go
+++ b/backend/internal/service/subscription_calculate_progress_test.go
@@ -66,6 +66,30 @@ func TestCalculateProgress_DailyUsage(t *testing.T) {
assert.Equal(t, dailyStart, progress.Daily.WindowStart)
}
+func TestCalculateProgress_DailyCardUsesExpiryAsDailyResetTime(t *testing.T) {
+ svc := newTestSubscriptionService()
+ startsAt := time.Now().Add(-12 * time.Hour)
+ dailyStart := time.Date(startsAt.Year(), startsAt.Month(), startsAt.Day(), 0, 0, 0, 0, startsAt.Location())
+ expiresAt := startsAt.Add(24 * time.Hour)
+
+ sub := &UserSubscription{
+ ID: 1,
+ StartsAt: startsAt,
+ ExpiresAt: expiresAt,
+ DailyUsageUSD: 3.0,
+ DailyWindowStart: ptrTime(dailyStart),
+ }
+ group := &Group{
+ Name: "Daily",
+ DailyLimitUSD: ptrFloat64(10.0),
+ }
+
+ progress := svc.calculateProgress(sub, group)
+
+ require.NotNil(t, progress.Daily, "日卡有日限额和窗口时 Daily 不应为 nil")
+ assert.Equal(t, expiresAt, progress.Daily.ResetsAt, "日卡的一次性日额度结束时间应为订阅过期时间")
+}
+
func TestCalculateProgress_WeeklyUsage(t *testing.T) {
svc := newTestSubscriptionService()
now := time.Now()
diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go
index f0a5540e..9905e6a1 100644
--- a/backend/internal/service/subscription_service.go
+++ b/backend/internal/service/subscription_service.go
@@ -196,7 +196,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
now := time.Now()
var newExpiresAt time.Time
- if existingSub.ExpiresAt.After(now) {
+ isExpired := !existingSub.ExpiresAt.After(now)
+ if !isExpired {
// 未过期:从当前过期时间累加
newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays)
} else {
@@ -209,43 +210,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newExpiresAt = MaxExpiresAt
}
- // 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成
- tx, err := s.entClient.Tx(ctx)
- if err != nil {
- return nil, false, fmt.Errorf("begin transaction: %w", err)
- }
- txCtx := dbent.NewTxContext(ctx, tx)
-
- // 更新过期时间
- if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
- _ = tx.Rollback()
- return nil, false, fmt.Errorf("extend subscription: %w", err)
- }
-
- // 如果订阅已过期或被暂停,恢复为active状态
- if existingSub.Status != SubscriptionStatusActive {
- if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
- _ = tx.Rollback()
- return nil, false, fmt.Errorf("update subscription status: %w", err)
- }
- }
-
- // 追加备注
- if input.Notes != "" {
- newNotes := existingSub.Notes
- if newNotes != "" {
- newNotes += "\n"
- }
- newNotes += input.Notes
- if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil {
- _ = tx.Rollback()
- return nil, false, fmt.Errorf("update subscription notes: %w", err)
- }
- }
-
- // 提交事务
- if err := tx.Commit(); err != nil {
- return nil, false, fmt.Errorf("commit transaction: %w", err)
+ if err := s.updateExistingSubscriptionTerm(ctx, existingSub, input.Notes, now, newExpiresAt, isExpired); err != nil {
+ return nil, false, err
}
// 失效订阅缓存
@@ -284,6 +250,94 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
return sub, false, nil // false 表示是新建
}
+func (s *SubscriptionService) updateExistingSubscriptionTerm(
+ ctx context.Context,
+ existingSub *UserSubscription,
+ notes string,
+ startsAt time.Time,
+ newExpiresAt time.Time,
+ isExpired bool,
+) error {
+ return s.withSubscriptionUpdateTx(ctx, func(txCtx context.Context) error {
+ if isExpired {
+ renewed := renewedSubscriptionTerm(existingSub, notes, startsAt, newExpiresAt)
+ if err := s.userSubRepo.Update(txCtx, renewed); err != nil {
+ return fmt.Errorf("renew expired subscription: %w", err)
+ }
+ return nil
+ }
+
+ // 更新过期时间
+ if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil {
+ return fmt.Errorf("extend subscription: %w", err)
+ }
+
+ // 如果订阅被暂停,恢复为 active 状态
+ if existingSub.Status != SubscriptionStatusActive {
+ if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil {
+ return fmt.Errorf("update subscription status: %w", err)
+ }
+ }
+
+ // 追加备注
+ if notes != "" {
+ if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, appendSubscriptionNotes(existingSub.Notes, notes)); err != nil {
+ return fmt.Errorf("update subscription notes: %w", err)
+ }
+ }
+
+ return nil
+ })
+}
+
+func (s *SubscriptionService) withSubscriptionUpdateTx(ctx context.Context, fn func(context.Context) error) error {
+ if s.entClient == nil {
+ return fn(ctx)
+ }
+
+ tx, err := s.entClient.Tx(ctx)
+ if err != nil {
+ return fmt.Errorf("begin transaction: %w", err)
+ }
+ txCtx := dbent.NewTxContext(ctx, tx)
+
+ if err := fn(txCtx); err != nil {
+ _ = tx.Rollback()
+ return err
+ }
+
+ if err := tx.Commit(); err != nil {
+ return fmt.Errorf("commit transaction: %w", err)
+ }
+ return nil
+}
+
+func renewedSubscriptionTerm(existingSub *UserSubscription, notes string, startsAt, expiresAt time.Time) *UserSubscription {
+ renewed := *existingSub
+ windowStart := startOfDay(startsAt)
+ renewed.StartsAt = startsAt
+ renewed.ExpiresAt = expiresAt
+ renewed.Status = SubscriptionStatusActive
+ renewed.DailyWindowStart = &windowStart
+ renewed.WeeklyWindowStart = &windowStart
+ renewed.MonthlyWindowStart = &windowStart
+ renewed.DailyUsageUSD = 0
+ renewed.WeeklyUsageUSD = 0
+ renewed.MonthlyUsageUSD = 0
+ renewed.Notes = appendSubscriptionNotes(existingSub.Notes, notes)
+ return &renewed
+}
+
+func appendSubscriptionNotes(existingNotes, newNotes string) string {
+ if newNotes == "" {
+ return existingNotes
+ }
+ if existingNotes == "" {
+ return newNotes
+ }
+ return existingNotes + "\n" + newNotes
+}
+
// createSubscription 创建新订阅(内部方法)
func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) {
validityDays := input.ValidityDays
@@ -945,6 +999,9 @@ func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Gr
if group.HasDailyLimit() && sub.DailyWindowStart != nil {
limit := *group.DailyLimitUSD
resetsAt := sub.DailyWindowStart.Add(24 * time.Hour)
+ if dailyResetTime := sub.DailyResetTime(); dailyResetTime != nil {
+ resetsAt = *dailyResetTime
+ }
progress.Daily = &UsageWindowProgress{
LimitUSD: limit,
UsedUSD: sub.DailyUsageUSD,
diff --git a/backend/internal/service/upstream_models.go b/backend/internal/service/upstream_models.go
new file mode 100644
index 00000000..77e8d1e4
--- /dev/null
+++ b/backend/internal/service/upstream_models.go
@@ -0,0 +1,474 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "sort"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
+)
+
+const upstreamModelsBodyLimit int64 = 8 << 20
+
+// UpstreamModelSyncErrorKind classifies model sync failures for safe HTTP mapping.
+type UpstreamModelSyncErrorKind string
+
+const (
+ // UpstreamModelSyncErrorConfiguration means the account or server configuration cannot perform the sync.
+ UpstreamModelSyncErrorConfiguration UpstreamModelSyncErrorKind = "configuration"
+ // UpstreamModelSyncErrorUnsupported means the account format is intentionally unsupported for live model sync.
+ UpstreamModelSyncErrorUnsupported UpstreamModelSyncErrorKind = "unsupported"
+ // UpstreamModelSyncErrorUpstream means the configured upstream failed or returned an unusable response.
+ UpstreamModelSyncErrorUpstream UpstreamModelSyncErrorKind = "upstream"
+)
+
+// UpstreamModelSyncError keeps internal failure details wrapped while exposing a safe client message.
+type UpstreamModelSyncError struct {
+ Kind UpstreamModelSyncErrorKind
+ Message string
+ Err error
+}
+
+func (e *UpstreamModelSyncError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.Err == nil {
+ return e.Message
+ }
+ return e.Message + ": " + e.Err.Error()
+}
+
+func (e *UpstreamModelSyncError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+// SafeMessage returns the sanitized message that can be sent to API clients.
+func (e *UpstreamModelSyncError) SafeMessage() string {
+ if e == nil || strings.TrimSpace(e.Message) == "" {
+ return "Failed to sync upstream models"
+ }
+ return e.Message
+}
+
+func newUpstreamModelSyncConfigError(message string, err error) error {
+ return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorConfiguration, Message: message, Err: err}
+}
+
+func newUpstreamModelSyncUnsupportedError(message string, err error) error {
+ return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUnsupported, Message: message, Err: err}
+}
+
+func newUpstreamModelSyncUpstreamError(message string, err error) error {
+ return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUpstream, Message: message, Err: err}
+}
+
+// FetchUpstreamSupportedModels fetches the live model list from the account's upstream API format.
+func (s *AccountTestService) FetchUpstreamSupportedModels(ctx context.Context, account *Account) ([]string, error) {
+ if s == nil {
+ return nil, newUpstreamModelSyncConfigError("Account test service is not configured", nil)
+ }
+ if account == nil {
+ return nil, newUpstreamModelSyncConfigError("Account is required", nil)
+ }
+
+ if account.Platform == PlatformAntigravity && account.Type != AccountTypeAPIKey {
+ return s.fetchAntigravityOAuthUpstreamModels(ctx, account)
+ }
+
+ if s.httpUpstream == nil {
+ return nil, newUpstreamModelSyncConfigError("Upstream HTTP client is not configured", nil)
+ }
+
+ req, err := s.buildUpstreamModelsRequest(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+
+ proxyURL := upstreamModelsProxyURL(account)
+ resp, err := s.doUpstreamModelsRequest(req, proxyURL, account)
+ if err != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to request upstream model list", err)
+ }
+ defer func() { _ = resp.Body.Close() }()
+
+ body, err := io.ReadAll(io.LimitReader(resp.Body, upstreamModelsBodyLimit+1))
+ if err != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to read upstream model list", err)
+ }
+ if int64(len(body)) > upstreamModelsBodyLimit {
+ return nil, newUpstreamModelSyncUpstreamError("Upstream model list response is too large", fmt.Errorf("response exceeds %d bytes", upstreamModelsBodyLimit))
+ }
+
+ if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
+ return nil, newUpstreamModelSyncUpstreamError(
+ fmt.Sprintf("Upstream model list request failed with HTTP %d", resp.StatusCode),
+ fmt.Errorf("upstream model list returned HTTP %d", resp.StatusCode),
+ )
+ }
+
+ models, err := extractUpstreamModelIDs(body)
+ if err != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Upstream model list response was not valid JSON", err)
+ }
+ if len(models) == 0 {
+ return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil)
+ }
+
+ return models, nil
+}
+
+func (s *AccountTestService) buildUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
+ switch {
+ case account.Platform == PlatformAntigravity:
+ return s.buildAntigravityAPIKeyModelsRequest(ctx, account)
+ case account.IsOpenAI():
+ return s.buildOpenAIUpstreamModelsRequest(ctx, account)
+ case account.IsGemini():
+ return s.buildGeminiUpstreamModelsRequest(ctx, account)
+ case account.IsAnthropic():
+ return s.buildAnthropicUpstreamModelsRequest(ctx, account)
+ default:
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported platform for upstream model sync: %s", account.Platform), nil,
+ )
+ }
+}
+
+func (s *AccountTestService) buildAnthropicUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
+ if account.IsBedrock() || account.Type == AccountTypeServiceAccount {
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil,
+ )
+ }
+
+ baseURL := "https://api.anthropic.com"
+ authHeaderName := ""
+ authHeaderValue := ""
+ betaHeader := ""
+
+ if account.IsOAuth() {
+ accessToken := strings.TrimSpace(account.GetCredential("access_token"))
+ if accessToken == "" && s.claudeTokenProvider != nil {
+ token, tokenErr := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if tokenErr != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to get Anthropic access token", tokenErr)
+ }
+ accessToken = strings.TrimSpace(token)
+ }
+ if accessToken == "" {
+ return nil, newUpstreamModelSyncConfigError("No Anthropic access token is available", nil)
+ }
+ authHeaderName = "Authorization"
+ authHeaderValue = "Bearer " + accessToken
+ betaHeader = claude.DefaultBetaHeader
+ } else if account.Type == AccountTypeAPIKey {
+ apiKey := strings.TrimSpace(account.GetCredential("api_key"))
+ if apiKey == "" {
+ return nil, newUpstreamModelSyncConfigError("No Anthropic API key is available", nil)
+ }
+ baseURL = account.GetBaseURL()
+ if strings.TrimSpace(baseURL) == "" {
+ baseURL = "https://api.anthropic.com"
+ }
+ authHeaderName = "x-api-key"
+ authHeaderValue = apiKey
+ betaHeader = claude.APIKeyBetaHeader
+ } else {
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil,
+ )
+ }
+
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Anthropic base URL", err)
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Anthropic model list URL", err)
+ }
+ for key, value := range claude.DefaultHeaders {
+ req.Header.Set(key, value)
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("anthropic-version", "2023-06-01")
+ req.Header.Set("anthropic-beta", betaHeader)
+ req.Header.Set(authHeaderName, authHeaderValue)
+ return req, nil
+}
+
+func (s *AccountTestService) buildAntigravityAPIKeyModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
+ if account.Type != AccountTypeAPIKey {
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported Antigravity account type for upstream model sync: %s", account.Type), nil,
+ )
+ }
+ apiKey := strings.TrimSpace(account.GetCredential("api_key"))
+ if apiKey == "" {
+ return nil, newUpstreamModelSyncConfigError("No Antigravity API key is available", nil)
+ }
+
+ baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/")
+ if baseURL == "" {
+ return nil, newUpstreamModelSyncConfigError("Antigravity API-key base URL is required for upstream model sync", nil)
+ }
+ if !strings.HasSuffix(strings.ToLower(baseURL), "/antigravity") {
+ return nil, newUpstreamModelSyncUnsupportedError(
+ "Antigravity API-key upstream model sync requires a compatible gateway base URL ending in /antigravity; use Antigravity OAuth for official Cloud Code upstreams",
+ nil,
+ )
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Antigravity base URL", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Antigravity model list URL", err)
+ }
+ for key, value := range claude.DefaultHeaders {
+ req.Header.Set(key, value)
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("anthropic-version", "2023-06-01")
+ req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader)
+ req.Header.Set("x-api-key", apiKey)
+ return req, nil
+}
+
+func (s *AccountTestService) buildOpenAIUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
+ if account.Type != AccountTypeAPIKey {
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported OpenAI account type for upstream model sync: %s", account.Type), nil,
+ )
+ }
+ apiKey := strings.TrimSpace(account.GetOpenAIApiKey())
+ if apiKey == "" {
+ return nil, newUpstreamModelSyncConfigError("No OpenAI API key is available", nil)
+ }
+
+ baseURL := account.GetOpenAIBaseURL()
+ if strings.TrimSpace(baseURL) == "" {
+ baseURL = "https://api.openai.com"
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid OpenAI base URL", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildOpenAIModelsURL(normalizedBaseURL), nil)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid OpenAI model list URL", err)
+ }
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("Authorization", "Bearer "+apiKey)
+ return req, nil
+}
+
+func (s *AccountTestService) buildGeminiUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) {
+ baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL)
+ if strings.TrimSpace(baseURL) == "" {
+ baseURL = geminicli.AIStudioBaseURL
+ }
+ normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Gemini base URL", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildGeminiModelsURL(normalizedBaseURL), nil)
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Invalid Gemini model list URL", err)
+ }
+ req.Header.Set("Accept", "application/json")
+
+ switch account.Type {
+ case AccountTypeAPIKey:
+ apiKey := strings.TrimSpace(account.GetCredential("api_key"))
+ if apiKey == "" {
+ return nil, newUpstreamModelSyncConfigError("No Gemini API key is available", nil)
+ }
+ req.Header.Set("x-goog-api-key", apiKey)
+ case AccountTypeOAuth:
+ if strings.TrimSpace(account.GetCredential("project_id")) != "" {
+ return nil, newUpstreamModelSyncUnsupportedError("Gemini Code Assist model listing is not supported by this sync button", nil)
+ }
+ if s.geminiTokenProvider == nil {
+ return nil, newUpstreamModelSyncConfigError("Gemini token provider is not configured", nil)
+ }
+ accessToken, tokenErr := s.geminiTokenProvider.GetAccessToken(ctx, account)
+ if tokenErr != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to get Gemini access token", tokenErr)
+ }
+ accessToken = strings.TrimSpace(accessToken)
+ if accessToken == "" {
+ return nil, newUpstreamModelSyncConfigError("No Gemini access token is available", nil)
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ default:
+ return nil, newUpstreamModelSyncUnsupportedError(
+ fmt.Sprintf("Unsupported Gemini account type for upstream model sync: %s", account.Type), nil,
+ )
+ }
+
+ return req, nil
+}
+
+func (s *AccountTestService) fetchAntigravityOAuthUpstreamModels(ctx context.Context, account *Account) ([]string, error) {
+ if s.antigravityGatewayService == nil || s.antigravityGatewayService.GetTokenProvider() == nil {
+ return nil, newUpstreamModelSyncConfigError("Antigravity token provider is not configured", nil)
+ }
+
+ accessToken, err := s.antigravityGatewayService.GetTokenProvider().GetAccessToken(ctx, account)
+ if err != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to get Antigravity access token", err)
+ }
+ accessToken = strings.TrimSpace(accessToken)
+ if accessToken == "" {
+ return nil, newUpstreamModelSyncConfigError("No Antigravity access token is available", nil)
+ }
+
+ client, err := antigravity.NewClient(upstreamModelsProxyURL(account))
+ if err != nil {
+ return nil, newUpstreamModelSyncConfigError("Failed to configure Antigravity client", err)
+ }
+ modelsResp, _, err := client.FetchAvailableModels(ctx, accessToken, strings.TrimSpace(account.GetCredential("project_id")))
+ if err != nil {
+ return nil, newUpstreamModelSyncUpstreamError("Failed to fetch Antigravity available models", err)
+ }
+ if modelsResp == nil || len(modelsResp.Models) == 0 {
+ return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil)
+ }
+
+ models := make([]string, 0, len(modelsResp.Models))
+ for modelID := range modelsResp.Models {
+ models = append(models, strings.TrimSpace(modelID))
+ }
+ return dedupeAndSortModelIDs(models), nil
+}
+
+func (s *AccountTestService) doUpstreamModelsRequest(req *http.Request, proxyURL string, account *Account) (*http.Response, error) {
+ if s.tlsFPProfileService == nil {
+ return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, nil)
+ }
+ return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account))
+}
+
+func upstreamModelsProxyURL(account *Account) string {
+ if account != nil && account.ProxyID != nil && account.Proxy != nil {
+ return account.Proxy.URL()
+ }
+ return ""
+}
+
+func buildV1ModelsURL(base string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ if strings.HasSuffix(normalized, "/v1/models") {
+ return normalized
+ }
+ if strings.HasSuffix(normalized, "/v1") {
+ return normalized + "/models"
+ }
+ return normalized + "/v1/models"
+}
+
+func buildOpenAIModelsURL(base string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ if strings.HasSuffix(normalized, "/v1/models") {
+ return normalized
+ }
+ if strings.HasSuffix(normalized, "/v1") {
+ return normalized + "/models"
+ }
+ return normalized + "/v1/models"
+}
+
+func buildGeminiModelsURL(base string) string {
+ normalized := strings.TrimRight(strings.TrimSpace(base), "/")
+ if strings.HasSuffix(normalized, "/v1beta/models") {
+ return normalized
+ }
+ if strings.HasSuffix(normalized, "/v1beta") {
+ return normalized + "/models"
+ }
+ return normalized + "/v1beta/models"
+}
+
+type upstreamModelEntry struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+func extractUpstreamModelIDs(body []byte) ([]string, error) {
+ var response struct {
+ Data []upstreamModelEntry `json:"data"`
+ Models []upstreamModelEntry `json:"models"`
+ }
+ if err := json.Unmarshal(body, &response); err != nil {
+ var arrayResponse []upstreamModelEntry
+ if arrayErr := json.Unmarshal(body, &arrayResponse); arrayErr != nil {
+ return nil, fmt.Errorf("parse upstream model list: %w", err)
+ }
+
+ models := make([]string, 0, len(arrayResponse))
+ for _, entry := range arrayResponse {
+ models = append(models, upstreamModelEntryID(entry))
+ }
+ return dedupeAndSortModelIDs(models), nil
+ }
+
+ models := make([]string, 0, len(response.Data)+len(response.Models))
+ for _, entry := range response.Data {
+ models = append(models, upstreamModelEntryID(entry))
+ }
+ for _, entry := range response.Models {
+ models = append(models, upstreamModelEntryID(entry))
+ }
+
+ if len(models) == 0 {
+ var arrayResponse []upstreamModelEntry
+ if err := json.Unmarshal(body, &arrayResponse); err == nil {
+ for _, entry := range arrayResponse {
+ models = append(models, upstreamModelEntryID(entry))
+ }
+ }
+ }
+
+ return dedupeAndSortModelIDs(models), nil
+}
+
+func upstreamModelEntryID(entry upstreamModelEntry) string {
+ modelID := strings.TrimSpace(entry.ID)
+ if modelID == "" {
+ modelID = strings.TrimSpace(entry.Name)
+ }
+ return strings.TrimPrefix(modelID, "models/")
+}
+
+func dedupeAndSortModelIDs(models []string) []string {
+ seen := make(map[string]struct{}, len(models))
+ result := make([]string, 0, len(models))
+ for _, model := range models {
+ model = strings.TrimSpace(model)
+ if model == "" {
+ continue
+ }
+ if _, exists := seen[model]; exists {
+ continue
+ }
+ seen[model] = struct{}{}
+ result = append(result, model)
+ }
+ sort.Strings(result)
+ return result
+}
diff --git a/backend/internal/service/upstream_models_test.go b/backend/internal/service/upstream_models_test.go
new file mode 100644
index 00000000..6831e791
--- /dev/null
+++ b/backend/internal/service/upstream_models_test.go
@@ -0,0 +1,226 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func upstreamModelSyncTestConfig() *config.Config {
+ return &config.Config{
+ Security: config.SecurityConfig{
+ URLAllowlist: config.URLAllowlistConfig{Enabled: false},
+ },
+ }
+}
+
+func TestBuildV1ModelsURL(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com"))
+ require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1"))
+ require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1/models"))
+ require.Equal(t, "https://gateway.example.com/antigravity/v1/models", buildV1ModelsURL("https://gateway.example.com/antigravity/"))
+}
+
+func TestBuildGeminiModelsURL(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com"))
+ require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta"))
+ require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta/models"))
+}
+
+func TestExtractUpstreamModelIDs(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ body string
+ want []string
+ }{
+ {
+ name: "openai and anthropic data array",
+ body: `{"data":[{"id":"claude-sonnet-4-5"},{"id":"gpt-5"},{"id":"gpt-5"},{"id":""}]}`,
+ want: []string{"claude-sonnet-4-5", "gpt-5"},
+ },
+ {
+ name: "gemini models array strips prefix",
+ body: `{"models":[{"name":"models/gemini-2.5-pro"},{"name":"gemini-2.5-flash"}]}`,
+ want: []string{"gemini-2.5-flash", "gemini-2.5-pro"},
+ },
+ {
+ name: "top level array",
+ body: `[{"id":"z-model"},{"name":"models/a-model"}]`,
+ want: []string{"a-model", "z-model"},
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ got, err := extractUpstreamModelIDs([]byte(tt.body))
+ require.NoError(t, err)
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestBuildUpstreamModelsRequestsForAPIKeyAccounts(t *testing.T) {
+ t.Parallel()
+
+ svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
+ ctx := context.Background()
+
+ anthropicReq, err := svc.buildAnthropicUpstreamModelsRequest(ctx, &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "anthropic-key",
+ "base_url": "https://anthropic.example.com/v1",
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, "https://anthropic.example.com/v1/models", anthropicReq.URL.String())
+ require.Equal(t, "anthropic-key", anthropicReq.Header.Get("x-api-key"))
+ require.Equal(t, "2023-06-01", anthropicReq.Header.Get("anthropic-version"))
+
+ openAIReq, err := svc.buildOpenAIUpstreamModelsRequest(ctx, &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "openai-key",
+ "base_url": "https://openai.example.com",
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, "https://openai.example.com/v1/models", openAIReq.URL.String())
+ require.Equal(t, "Bearer openai-key", openAIReq.Header.Get("Authorization"))
+
+ geminiReq, err := svc.buildGeminiUpstreamModelsRequest(ctx, &Account{
+ Platform: PlatformGemini,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "gemini-key",
+ "base_url": "https://generativelanguage.googleapis.com/v1beta",
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", geminiReq.URL.String())
+ require.Equal(t, "gemini-key", geminiReq.Header.Get("x-goog-api-key"))
+
+ antigravityReq, err := svc.buildAntigravityAPIKeyModelsRequest(ctx, &Account{
+ Platform: PlatformAntigravity,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "antigravity-key",
+ "base_url": "https://gateway.example.com/antigravity",
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, "https://gateway.example.com/antigravity/v1/models", antigravityReq.URL.String())
+ require.Equal(t, "antigravity-key", antigravityReq.Header.Get("x-api-key"))
+}
+
+func TestBuildAntigravityAPIKeyModelsRequestRejectsOfficialCloudCodeBase(t *testing.T) {
+ t.Parallel()
+
+ svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
+ _, err := svc.buildAntigravityAPIKeyModelsRequest(context.Background(), &Account{
+ Platform: PlatformAntigravity,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "antigravity-key",
+ "base_url": "https://cloudcode-pa.googleapis.com",
+ },
+ })
+ require.Error(t, err)
+
+ var syncErr *UpstreamModelSyncError
+ require.True(t, errors.As(err, &syncErr))
+ require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind)
+ require.Contains(t, syncErr.SafeMessage(), "compatible gateway")
+}
+
+func TestBuildAnthropicUpstreamModelsRequestRejectsBedrock(t *testing.T) {
+ t.Parallel()
+
+ svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()}
+ _, err := svc.buildAnthropicUpstreamModelsRequest(context.Background(), &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeBedrock,
+ })
+ require.Error(t, err)
+
+ var syncErr *UpstreamModelSyncError
+ require.True(t, errors.As(err, &syncErr))
+ require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind)
+}
+
+func TestFetchUpstreamSupportedModelsParsesOpenAIResponse(t *testing.T) {
+ t.Parallel()
+
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusOK,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"data":[{"id":"gpt-5"},{"id":"gpt-5"},{"name":"o3"}]}`)),
+ }}
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: upstreamModelSyncTestConfig(),
+ }
+
+ models, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{
+ ID: 7,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "openai-key",
+ "base_url": "https://openai.example.com/v1",
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, []string{"gpt-5", "o3"}, models)
+ require.Equal(t, "https://openai.example.com/v1/models", upstream.lastReq.URL.String())
+ require.Equal(t, "Bearer openai-key", upstream.lastReq.Header.Get("Authorization"))
+}
+
+func TestFetchUpstreamSupportedModelsDoesNotExposeUpstreamBody(t *testing.T) {
+ t.Parallel()
+
+ upstream := &httpUpstreamRecorder{resp: &http.Response{
+ StatusCode: http.StatusBadGateway,
+ Header: http.Header{"Content-Type": []string{"application/json"}},
+ Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
+ }}
+ svc := &AccountTestService{
+ httpUpstream: upstream,
+ cfg: upstreamModelSyncTestConfig(),
+ }
+
+ _, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{
+ ID: 8,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "openai-key",
+ "base_url": "https://openai.example.com/v1",
+ },
+ })
+ require.Error(t, err)
+ require.NotContains(t, err.Error(), "SECRET_TOKEN")
+
+ var syncErr *UpstreamModelSyncError
+ require.True(t, errors.As(err, &syncErr))
+ require.Equal(t, UpstreamModelSyncErrorUpstream, syncErr.Kind)
+ require.NotContains(t, syncErr.SafeMessage(), "SECRET_TOKEN")
+ require.Contains(t, syncErr.SafeMessage(), "HTTP 502")
+}
diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go
index e29d282e..d63f47cc 100644
--- a/backend/internal/service/usage_log.go
+++ b/backend/internal/service/usage_log.go
@@ -162,9 +162,13 @@ type UsageLog struct {
CacheTTLOverridden bool
// 图片生成字段
- ImageCount int
- ImageSize *string
- MediaType *string
+ ImageCount int
+ ImageSize *string
+ ImageInputSize *string
+ ImageOutputSize *string
+ ImageSizeSource *string
+ ImageSizeBreakdown map[string]int
+ MediaType *string
CreatedAt time.Time
diff --git a/backend/internal/service/user_attribute_service.go b/backend/internal/service/user_attribute_service.go
index 6c2f8077..ef19e078 100644
--- a/backend/internal/service/user_attribute_service.go
+++ b/backend/internal/service/user_attribute_service.go
@@ -72,6 +72,11 @@ func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*Us
return s.defRepo.GetByID(ctx, id)
}
+// GetDefinitionByKey retrieves a definition by its unique key
+func (s *UserAttributeService) GetDefinitionByKey(ctx context.Context, key string) (*UserAttributeDefinition, error) {
+ return s.defRepo.GetByKey(ctx, key)
+}
+
// ListDefinitions lists all definitions
func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) {
return s.defRepo.List(ctx, enabledOnly)
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 208a05db..61e9f846 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -142,10 +142,11 @@ type UserIdentitySummary struct {
}
type UserIdentitySummarySet struct {
- Email UserIdentitySummary `json:"email"`
- LinuxDo UserIdentitySummary `json:"linuxdo"`
- OIDC UserIdentitySummary `json:"oidc"`
- WeChat UserIdentitySummary `json:"wechat"`
+ Email UserIdentitySummary `json:"email"`
+ LinuxDo UserIdentitySummary `json:"linuxdo"`
+ OIDC UserIdentitySummary `json:"oidc"`
+ WeChat UserIdentitySummary `json:"wechat"`
+ DingTalk UserIdentitySummary `json:"dingtalk"`
}
type StartUserIdentityBindingRequest struct {
@@ -261,10 +262,11 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in
}
summaries := UserIdentitySummarySet{
- Email: s.buildEmailIdentitySummary(user, records),
- LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
- OIDC: s.buildProviderIdentitySummary("oidc", user, records),
- WeChat: s.buildProviderIdentitySummary("wechat", user, records),
+ Email: s.buildEmailIdentitySummary(user, records),
+ LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records),
+ OIDC: s.buildProviderIdentitySummary("oidc", user, records),
+ WeChat: s.buildProviderIdentitySummary("wechat", user, records),
+ DingTalk: s.buildProviderIdentitySummary("dingtalk", user, records),
}
s.applyExplicitProviderAvailability(ctx, &summaries)
@@ -284,6 +286,7 @@ func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, sum
SettingKeyWeChatConnectMPEnabled,
SettingKeyWeChatConnectMobileEnabled,
SettingKeyWeChatConnectMode,
+ SettingKeyDingTalkConnectEnabled,
})
if err != nil {
return
@@ -292,6 +295,9 @@ func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, sum
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
disableIdentityBindAction(&summaries.LinuxDo)
}
+ if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
+ disableIdentityBindAction(&summaries.DingTalk)
+ }
if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" {
disableIdentityBindAction(&summaries.OIDC)
}
@@ -697,7 +703,7 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U
return true
}
- for _, candidate := range []string{"linuxdo", "oidc", "wechat"} {
+ for _, candidate := range []string{"linuxdo", "oidc", "wechat", "dingtalk"} {
if candidate == provider {
continue
}
@@ -773,6 +779,8 @@ func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, err
path = "/api/v1/auth/oauth/oidc/bind/start"
case "wechat":
path = "/api/v1/auth/oauth/wechat/bind/start"
+ case "dingtalk":
+ path = "/api/v1/auth/oauth/dingtalk/bind/start"
default:
return "", ErrIdentityProviderInvalid
}
@@ -791,6 +799,8 @@ func normalizeUserIdentityProvider(provider string) string {
return "oidc"
case "wechat":
return "wechat"
+ case "dingtalk":
+ return "dingtalk"
case "email":
return "email"
default:
diff --git a/backend/internal/service/user_subscription.go b/backend/internal/service/user_subscription.go
index ec547d81..6303e6e3 100644
--- a/backend/internal/service/user_subscription.go
+++ b/backend/internal/service/user_subscription.go
@@ -50,11 +50,25 @@ func (s *UserSubscription) IsWindowActivated() bool {
return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil
}
+func (s *UserSubscription) HasOneTimeDailyQuota() bool {
+ if s == nil || s.StartsAt.IsZero() || s.ExpiresAt.IsZero() {
+ return false
+ }
+ return !s.ExpiresAt.After(s.StartsAt.AddDate(0, 0, 1))
+}
+
func (s *UserSubscription) NeedsDailyReset() bool {
+ return s.NeedsDailyResetAt(time.Now())
+}
+
+func (s *UserSubscription) NeedsDailyResetAt(now time.Time) bool {
if s.DailyWindowStart == nil {
return false
}
- return time.Since(*s.DailyWindowStart) >= 24*time.Hour
+ if s.HasOneTimeDailyQuota() {
+ return false
+ }
+ return !now.Before(s.DailyWindowStart.Add(24 * time.Hour))
}
func (s *UserSubscription) NeedsWeeklyReset() bool {
@@ -75,6 +89,10 @@ func (s *UserSubscription) DailyResetTime() *time.Time {
if s.DailyWindowStart == nil {
return nil
}
+ if s.HasOneTimeDailyQuota() {
+ t := s.ExpiresAt
+ return &t
+ }
t := s.DailyWindowStart.Add(24 * time.Hour)
return &t
}
diff --git a/backend/internal/service/user_subscription_daily_quota_test.go b/backend/internal/service/user_subscription_daily_quota_test.go
new file mode 100644
index 00000000..3738bdd6
--- /dev/null
+++ b/backend/internal/service/user_subscription_daily_quota_test.go
@@ -0,0 +1,178 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type dailyResetTrackingUserSubRepo struct {
+ userSubRepoNoop
+
+ resetDailyCalled bool
+}
+
+func (r *dailyResetTrackingUserSubRepo) ResetDailyUsage(context.Context, int64, time.Time) error {
+ r.resetDailyCalled = true
+ return nil
+}
+
+func TestAssignOrExtendSubscription_ExpiredDailyCardStartsNewOneTimeQuota(t *testing.T) {
+ groupRepo := &subscriptionGroupRepoStub{
+ group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription},
+ }
+ subRepo := newSubscriptionUserSubRepoStub()
+ oldStart := time.Now().AddDate(0, 0, -3)
+ oldWindowStart := startOfDay(oldStart)
+ subRepo.seed(&UserSubscription{
+ ID: 100,
+ UserID: 200,
+ GroupID: 1,
+ StartsAt: oldStart,
+ ExpiresAt: oldStart.AddDate(0, 0, 1),
+ Status: SubscriptionStatusExpired,
+ DailyWindowStart: &oldWindowStart,
+ WeeklyWindowStart: &oldWindowStart,
+ MonthlyWindowStart: &oldWindowStart,
+ DailyUsageUSD: 10,
+ WeeklyUsageUSD: 20,
+ MonthlyUsageUSD: 30,
+ Notes: "old",
+ })
+ svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil)
+
+ renewed, reused, err := svc.AssignOrExtendSubscription(context.Background(), &AssignSubscriptionInput{
+ UserID: 200,
+ GroupID: 1,
+ ValidityDays: 1,
+ Notes: "new",
+ })
+
+ require.NoError(t, err)
+ require.True(t, reused)
+ require.True(t, renewed.HasOneTimeDailyQuota(), "过期后重新购买 1 日卡仍应被识别为一次性日额度")
+ require.Equal(t, SubscriptionStatusActive, renewed.Status)
+ require.True(t, renewed.StartsAt.After(oldStart), "重新购买过期订阅时应重置当前周期 StartsAt")
+ require.False(t, renewed.ExpiresAt.After(renewed.StartsAt.AddDate(0, 0, 1)))
+ require.NotNil(t, renewed.DailyWindowStart)
+ require.Equal(t, startOfDay(renewed.StartsAt), *renewed.DailyWindowStart)
+ require.Equal(t, 0.0, renewed.DailyUsageUSD)
+ require.Equal(t, 0.0, renewed.WeeklyUsageUSD)
+ require.Equal(t, 0.0, renewed.MonthlyUsageUSD)
+ require.Equal(t, "old\nnew", renewed.Notes)
+}
+
+func TestUserSubscriptionNeedsDailyReset_DailyCardKeepsOneTimeQuota(t *testing.T) {
+ start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC)
+ dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)
+ sub := &UserSubscription{
+ StartsAt: start,
+ ExpiresAt: start.Add(24 * time.Hour),
+ DailyWindowStart: &dailyWindowStart,
+ DailyUsageUSD: 10,
+ }
+
+ require.True(t, sub.HasOneTimeDailyQuota())
+ require.False(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(25*time.Hour)), "日卡应作为一次性配额,跨 0 点后不再刷新日额度")
+}
+
+func TestUserSubscriptionNeedsDailyReset_MultiDaySubscriptionStillRefreshes(t *testing.T) {
+ start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC)
+ dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)
+ sub := &UserSubscription{
+ StartsAt: start,
+ ExpiresAt: start.AddDate(0, 0, 2),
+ DailyWindowStart: &dailyWindowStart,
+ }
+
+ require.False(t, sub.HasOneTimeDailyQuota())
+ require.True(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(24*time.Hour)), "多日订阅仍应按 24 小时日窗口刷新")
+}
+
+func TestUserSubscriptionDailyResetTime_DailyCardReturnsExpiry(t *testing.T) {
+ start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC)
+ dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC)
+ expiresAt := start.Add(24 * time.Hour)
+ sub := &UserSubscription{
+ StartsAt: start,
+ ExpiresAt: expiresAt,
+ DailyWindowStart: &dailyWindowStart,
+ }
+
+ resetAt := sub.DailyResetTime()
+ require.NotNil(t, resetAt)
+ require.Equal(t, expiresAt, *resetAt, "日卡展示的日额度结束时间应为订阅过期时间")
+}
+
+func TestCheckAndResetWindows_DailyCardDoesNotResetDailyUsage(t *testing.T) {
+ now := time.Now()
+ startsAt := now.Add(-23 * time.Hour)
+ dailyWindowStart := now.Add(-25 * time.Hour)
+ repo := &dailyResetTrackingUserSubRepo{}
+ svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil)
+ sub := &UserSubscription{
+ ID: 1,
+ UserID: 10,
+ GroupID: 20,
+ StartsAt: startsAt,
+ ExpiresAt: startsAt.Add(24 * time.Hour),
+ DailyUsageUSD: 10,
+ DailyWindowStart: &dailyWindowStart,
+ }
+
+ err := svc.CheckAndResetWindows(context.Background(), sub)
+
+ require.NoError(t, err)
+ require.False(t, repo.resetDailyCalled, "日卡作为一次性配额,过了 24 小时日窗口也不应重置 daily usage")
+ require.Equal(t, 10.0, sub.DailyUsageUSD)
+}
+
+func TestCheckAndResetWindows_MultiDaySubscriptionStillResetsDailyUsage(t *testing.T) {
+ now := time.Now()
+ startsAt := now.Add(-48 * time.Hour)
+ dailyWindowStart := now.Add(-25 * time.Hour)
+ repo := &dailyResetTrackingUserSubRepo{}
+ svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil)
+ sub := &UserSubscription{
+ ID: 1,
+ UserID: 10,
+ GroupID: 20,
+ StartsAt: startsAt,
+ ExpiresAt: startsAt.AddDate(0, 0, 2),
+ DailyUsageUSD: 10,
+ DailyWindowStart: &dailyWindowStart,
+ }
+
+ err := svc.CheckAndResetWindows(context.Background(), sub)
+
+ require.NoError(t, err)
+ require.True(t, repo.resetDailyCalled, "多日订阅仍应重置过期 daily window")
+ require.Equal(t, 0.0, sub.DailyUsageUSD)
+}
+
+func TestValidateAndCheckLimits_DailyCardDoesNotAllowSecondQuotaAfterMidnight(t *testing.T) {
+ start := time.Now().Add(-23 * time.Hour)
+ dailyWindowStart := time.Now().Add(-25 * time.Hour)
+ dailyLimit := 10.0
+ sub := &UserSubscription{
+ Status: SubscriptionStatusActive,
+ StartsAt: start,
+ ExpiresAt: start.Add(24 * time.Hour),
+ DailyWindowStart: &dailyWindowStart,
+ DailyUsageUSD: dailyLimit + 0.01,
+ }
+ group := &Group{
+ SubscriptionType: SubscriptionTypeSubscription,
+ DailyLimitUSD: &dailyLimit,
+ }
+ svc := NewSubscriptionService(groupRepoNoop{}, userSubRepoNoop{}, nil, nil, nil)
+
+ needsMaintenance, err := svc.ValidateAndCheckLimits(sub, group)
+
+ require.False(t, needsMaintenance, "日卡跨过日窗口后不应触发 daily reset 维护")
+ require.True(t, errors.Is(err, ErrDailyLimitExceeded))
+ require.Equal(t, dailyLimit+0.01, sub.DailyUsageUSD, "热路径不应清零日卡已用额度")
+}
diff --git a/backend/migrations/136_add_dingtalk_provider_type.sql b/backend/migrations/136_add_dingtalk_provider_type.sql
new file mode 100644
index 00000000..79c7ba05
--- /dev/null
+++ b/backend/migrations/136_add_dingtalk_provider_type.sql
@@ -0,0 +1,27 @@
+ALTER TABLE users
+ DROP CONSTRAINT IF EXISTS users_signup_source_check;
+
+ALTER TABLE users
+ ADD CONSTRAINT users_signup_source_check
+ CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));
+
+ALTER TABLE auth_identities
+ DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check;
+
+ALTER TABLE auth_identities
+ ADD CONSTRAINT auth_identities_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));
+
+ALTER TABLE auth_identity_channels
+ DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check;
+
+ALTER TABLE auth_identity_channels
+ ADD CONSTRAINT auth_identity_channels_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));
+
+ALTER TABLE pending_auth_sessions
+ DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check;
+
+ALTER TABLE pending_auth_sessions
+ ADD CONSTRAINT pending_auth_sessions_provider_type_check
+ CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk'));
diff --git a/backend/migrations/136_remove_ops_retry_replay.sql b/backend/migrations/136_remove_ops_retry_replay.sql
new file mode 100644
index 00000000..12ecc77d
--- /dev/null
+++ b/backend/migrations/136_remove_ops_retry_replay.sql
@@ -0,0 +1,16 @@
+-- Remove unused Ops retry/replay storage.
+-- The retry endpoints are no longer exposed, so keeping request bodies and
+-- retry audit rows only increases write width, memory retention, and DB size.
+
+DROP TABLE IF EXISTS ops_retry_attempts CASCADE;
+
+ALTER TABLE ops_error_logs
+ DROP COLUMN IF EXISTS request_body,
+ DROP COLUMN IF EXISTS request_headers,
+ DROP COLUMN IF EXISTS request_body_truncated,
+ DROP COLUMN IF EXISTS request_body_bytes,
+ DROP COLUMN IF EXISTS is_retryable,
+ DROP COLUMN IF EXISTS retry_count,
+ DROP COLUMN IF EXISTS resolved_retry_id;
+
+COMMENT ON TABLE ops_error_logs IS 'Ops error logs (vNext). Stores sanitized error details; request replay storage removed.';
diff --git a/backend/migrations/136_usage_log_image_size_metadata.sql b/backend/migrations/136_usage_log_image_size_metadata.sql
new file mode 100644
index 00000000..76bcb956
--- /dev/null
+++ b/backend/migrations/136_usage_log_image_size_metadata.sql
@@ -0,0 +1,51 @@
+-- Add generated-image billing size audit metadata.
+-- `image_size` remains the canonical billing tier used for cost calculation.
+
+ALTER TABLE usage_logs
+ ADD COLUMN IF NOT EXISTS image_input_size VARCHAR(32);
+
+ALTER TABLE usage_logs
+ ADD COLUMN IF NOT EXISTS image_output_size VARCHAR(32);
+
+ALTER TABLE usage_logs
+ ADD COLUMN IF NOT EXISTS image_size_source VARCHAR(16);
+
+ALTER TABLE usage_logs
+ ADD COLUMN IF NOT EXISTS image_size_breakdown JSONB;
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'usage_logs_image_size_source_check'
+ AND conrelid = 'usage_logs'::regclass
+ ) THEN
+ ALTER TABLE usage_logs
+ ADD CONSTRAINT usage_logs_image_size_source_check
+ CHECK (
+ image_size_source IS NULL
+ OR image_size_source IN ('output', 'input', 'default', 'legacy')
+ ) NOT VALID;
+ END IF;
+END $$;
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM pg_constraint
+ WHERE conname = 'usage_logs_image_billing_size_check'
+ AND conrelid = 'usage_logs'::regclass
+ ) THEN
+ ALTER TABLE usage_logs
+ ADD CONSTRAINT usage_logs_image_billing_size_check
+ CHECK (
+ image_count <= 0
+ OR (
+ image_size IS NOT NULL
+ AND image_size IN ('1K', '2K', '4K', 'mixed')
+ )
+ ) NOT VALID;
+ END IF;
+END $$;
diff --git a/backend/migrations/137_redeem_code_expires_at.sql b/backend/migrations/137_redeem_code_expires_at.sql
new file mode 100644
index 00000000..4fa27927
--- /dev/null
+++ b/backend/migrations/137_redeem_code_expires_at.sql
@@ -0,0 +1,8 @@
+-- Add optional expiry time for redeem codes themselves.
+-- `validity_days` remains the subscription duration granted after redeeming.
+
+ALTER TABLE redeem_codes
+ ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ;
+
+CREATE INDEX IF NOT EXISTS idx_redeem_codes_expires_at
+ ON redeem_codes (expires_at);
diff --git a/backend/migrations/138_channel_monitor_openai_api_mode.sql b/backend/migrations/138_channel_monitor_openai_api_mode.sql
new file mode 100644
index 00000000..5b16f39c
--- /dev/null
+++ b/backend/migrations/138_channel_monitor_openai_api_mode.sql
@@ -0,0 +1,40 @@
+-- Migration: 137_channel_monitor_openai_api_mode
+-- 为渠道监控和请求模板增加 OpenAI 协议模式:
+-- chat_completions -> /v1/chat/completions + messages
+-- responses -> /v1/responses + instructions/input
+-- 历史数据默认保持 chat_completions,避免改变现有监控行为。
+
+ALTER TABLE channel_monitors
+ ADD COLUMN IF NOT EXISTS api_mode VARCHAR(32) NOT NULL DEFAULT 'chat_completions';
+
+ALTER TABLE channel_monitor_request_templates
+ ADD COLUMN IF NOT EXISTS api_mode VARCHAR(32) NOT NULL DEFAULT 'chat_completions';
+
+DO $$
+BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitors_api_mode_check'
+ AND table_name = 'channel_monitors'
+ ) THEN
+ ALTER TABLE channel_monitors
+ ADD CONSTRAINT channel_monitors_api_mode_check
+ CHECK (api_mode IN ('chat_completions', 'responses'));
+ END IF;
+
+ IF NOT EXISTS (
+ SELECT 1 FROM information_schema.table_constraints
+ WHERE constraint_name = 'channel_monitor_request_templates_api_mode_check'
+ AND table_name = 'channel_monitor_request_templates'
+ ) THEN
+ ALTER TABLE channel_monitor_request_templates
+ ADD CONSTRAINT channel_monitor_request_templates_api_mode_check
+ CHECK (api_mode IN ('chat_completions', 'responses'));
+ END IF;
+END $$;
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitors_provider_api_mode
+ ON channel_monitors (provider, api_mode);
+
+CREATE INDEX IF NOT EXISTS idx_channel_monitor_templates_provider_api_mode
+ ON channel_monitor_request_templates (provider, api_mode);
diff --git a/backend/migrations/139_seed_openai_monitor_templates.sql b/backend/migrations/139_seed_openai_monitor_templates.sql
new file mode 100644
index 00000000..6326cd0a
--- /dev/null
+++ b/backend/migrations/139_seed_openai_monitor_templates.sql
@@ -0,0 +1,47 @@
+-- Migration: 138_seed_openai_monitor_templates
+-- 内置 OpenAI 渠道监控模板。重点是把协议模式显式化:
+-- 1) OpenAI-compatible 使用 Chat Completions payload
+-- 2) Responses / 本站自检 使用 Responses payload,默认 body 由后端 adapter 填入 instructions + input
+-- 所有模板都可直接选择;ON CONFLICT 保证重复部署不覆盖用户编辑。
+
+INSERT INTO channel_monitor_request_templates (
+ name, provider, api_mode, description, extra_headers, body_override_mode, body_override
+)
+VALUES
+(
+ 'OpenAI Compatible 默认检测',
+ 'openai',
+ 'chat_completions',
+ '适用于大多数 OpenAI-compatible 上游:POST /v1/chat/completions,后端自动生成 messages 数学 challenge。',
+ '{}'::jsonb,
+ 'off',
+ NULL
+),
+(
+ 'OpenAI Compatible 低 token 检测',
+ 'openai',
+ 'chat_completions',
+ '仍走 /v1/chat/completions,仅把 max_tokens 调低;model/messages/stream 由后端保护,避免误伤 challenge。',
+ '{}'::jsonb,
+ 'merge',
+ '{"max_tokens": 20}'::jsonb
+),
+(
+ 'OpenAI Responses / 本站自检',
+ 'openai',
+ 'responses',
+ '适用于本站或原生 Responses API:POST /v1/responses,默认 payload 自动带 instructions 与 input,避免 Instructions are required。',
+ '{}'::jsonb,
+ 'off',
+ NULL
+),
+(
+ 'OpenAI Responses 低 token 检测',
+ 'openai',
+ 'responses',
+ '仍走 /v1/responses,仅把 max_output_tokens 调低;instructions/input/model/stream 由后端保护。',
+ '{}'::jsonb,
+ 'merge',
+ '{"max_output_tokens": 20}'::jsonb
+)
+ON CONFLICT (provider, name) DO NOTHING;
diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json
index 413c3065..76f8f90f 100644
--- a/backend/resources/model-pricing/model_prices_and_context_window.json
+++ b/backend/resources/model-pricing/model_prices_and_context_window.json
@@ -5224,6 +5224,39 @@
"supports_tool_choice": true,
"supports_vision": true
},
+ "codex-auto-review": {
+ "cache_read_input_token_cost": 2.5e-07,
+ "input_cost_per_token": 2.5e-06,
+ "litellm_provider": "openai",
+ "max_input_tokens": 1050000,
+ "max_output_tokens": 128000,
+ "max_tokens": 128000,
+ "mode": "chat",
+ "output_cost_per_token": 1.5e-05,
+ "supported_endpoints": [
+ "/v1/chat/completions",
+ "/v1/responses"
+ ],
+ "supported_modalities": [
+ "text",
+ "image"
+ ],
+ "supported_output_modalities": [
+ "text",
+ "image"
+ ],
+ "supports_function_calling": true,
+ "supports_native_streaming": true,
+ "supports_parallel_function_calling": true,
+ "supports_pdf_input": true,
+ "supports_prompt_caching": true,
+ "supports_reasoning": true,
+ "supports_response_schema": true,
+ "supports_service_tier": true,
+ "supports_system_messages": true,
+ "supports_tool_choice": true,
+ "supports_vision": true
+ },
"gpt-5.4-mini": {
"cache_read_input_token_cost": 7.5e-08,
"input_cost_per_token": 7.5e-07,
diff --git a/deploy/install.sh b/deploy/install.sh
index 6dcf4123..1846dede 100644
--- a/deploy/install.sh
+++ b/deploy/install.sh
@@ -7,6 +7,21 @@
set -e
+# Bash 4+ is required for associative arrays used by the localized message table.
+# Keep this guard before any Bash 4-only syntax so older shells fail with a clear hint.
+if [ -z "${BASH_VERSION:-}" ]; then
+ echo "Error: This installer must be run with Bash 4.0 or later." >&2
+ echo "Please install Bash 4+ and run it with that interpreter." >&2
+ exit 1
+fi
+
+BASH_MAJOR_VERSION="${BASH_VERSION%%.*}"
+if [ "$BASH_MAJOR_VERSION" -lt 4 ]; then
+ echo "Error: Bash 4.0 or later is required. Current version: $BASH_VERSION" >&2
+ echo "Please install Bash 4+ and retry with that interpreter." >&2
+ exit 1
+fi
+
# Colors
RED='\033[0;31m'
GREEN='\033[0;32m'
diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts
index 00ed4087..92b0abca 100644
--- a/frontend/src/api/admin/accounts.ts
+++ b/frontend/src/api/admin/accounts.ts
@@ -232,9 +232,12 @@ export async function clearError(id: number): Promise {
* @param id - Account ID
* @returns Account usage info
*/
-export async function getUsage(id: number, source?: 'passive' | 'active'): Promise {
+export async function getUsage(id: number, source?: 'passive' | 'active', force?: boolean): Promise {
+ const params: Record = {}
+ if (source) params.source = source
+ if (force) params.force = 'true'
const { data } = await apiClient.get(`/admin/accounts/${id}/usage`, {
- params: source ? { source } : undefined
+ params: Object.keys(params).length > 0 ? params : undefined
})
return data
}
@@ -446,6 +449,20 @@ export async function getAvailableModels(id: number): Promise {
return data
}
+export interface SyncUpstreamModelsResult {
+ models: string[]
+}
+
+/**
+ * Sync live supported models from the account's upstream model-list endpoint
+ * @param id - Account ID
+ * @returns List of model IDs returned by the upstream
+ */
+export async function syncUpstreamModels(id: number): Promise {
+ const { data } = await apiClient.post(`/admin/accounts/${id}/models/sync-upstream`)
+ return data
+}
+
export interface CRSPreviewAccount {
crs_account_id: string
kind: string
@@ -660,6 +677,7 @@ export const accountsAPI = {
resetTempUnschedulable,
setSchedulable,
getAvailableModels,
+ syncUpstreamModels,
generateAuthUrl,
exchangeCode,
refreshOpenAIToken,
diff --git a/frontend/src/api/admin/channelMonitor.ts b/frontend/src/api/admin/channelMonitor.ts
index 949c4bc8..bdef7d33 100644
--- a/frontend/src/api/admin/channelMonitor.ts
+++ b/frontend/src/api/admin/channelMonitor.ts
@@ -8,11 +8,13 @@ import { apiClient } from '../client'
export type Provider = 'openai' | 'anthropic' | 'gemini'
export type MonitorStatus = 'operational' | 'degraded' | 'failed' | 'error'
export type BodyOverrideMode = 'off' | 'merge' | 'replace'
+export type APIMode = 'chat_completions' | 'responses'
export interface ChannelMonitor {
id: number
name: string
provider: Provider
+ api_mode: APIMode
endpoint: string
api_key_masked: string
/**
@@ -70,6 +72,7 @@ export interface ListResponse {
export interface CreateParams {
name: string
provider: Provider
+ api_mode?: APIMode
endpoint: string
api_key: string
primary_model: string
diff --git a/frontend/src/api/admin/channelMonitorTemplate.ts b/frontend/src/api/admin/channelMonitorTemplate.ts
index 01b3c2d0..7b048104 100644
--- a/frontend/src/api/admin/channelMonitorTemplate.ts
+++ b/frontend/src/api/admin/channelMonitorTemplate.ts
@@ -6,12 +6,13 @@
*/
import { apiClient } from '../client'
-import type { BodyOverrideMode, Provider } from './channelMonitor'
+import type { APIMode, BodyOverrideMode, Provider } from './channelMonitor'
export interface ChannelMonitorTemplate {
id: number
name: string
provider: Provider
+ api_mode: APIMode
description: string
extra_headers: Record
body_override_mode: BodyOverrideMode
@@ -24,6 +25,7 @@ export interface ChannelMonitorTemplate {
export interface ListParams {
provider?: Provider
+ api_mode?: APIMode
}
export interface ListResponse {
@@ -33,6 +35,7 @@ export interface ListResponse {
export interface CreateParams {
name: string
provider: Provider
+ api_mode?: APIMode
description?: string
extra_headers?: Record
body_override_mode?: BodyOverrideMode
@@ -41,6 +44,7 @@ export interface CreateParams {
export interface UpdateParams {
name?: string
+ api_mode?: APIMode
description?: string
extra_headers?: Record
body_override_mode?: BodyOverrideMode
@@ -55,6 +59,7 @@ export interface AssociatedMonitorBrief {
id: number
name: string
provider: Provider
+ api_mode: APIMode
enabled: boolean
}
diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts
index 9d430134..afa43a2d 100644
--- a/frontend/src/api/admin/channels.ts
+++ b/frontend/src/api/admin/channels.ts
@@ -164,5 +164,19 @@ export async function getModelDefaultPricing(model: string): Promise {
+ const { data } = await apiClient.get('/admin/channels/pricing/sync-models', {
+ params: { platform }
+ })
+ return data
+}
+
+const channelsAPI = { list, getById, create, update, remove, getModelDefaultPricing, syncPricingModels }
export default channelsAPI
diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts
index 49e487b7..dda7d892 100644
--- a/frontend/src/api/admin/dashboard.ts
+++ b/frontend/src/api/admin/dashboard.ts
@@ -266,10 +266,17 @@ export async function getUserSpendingRanking(
return data
}
+export interface PlatformUsage {
+ platform: string
+ today_actual_cost: number
+ total_actual_cost: number
+}
+
export interface BatchUserUsageStats {
user_id: number
today_actual_cost: number
total_actual_cost: number
+ by_platform?: PlatformUsage[]
}
export interface BatchUsersUsageResponse {
diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts
index 0890fc4d..fff5014b 100644
--- a/frontend/src/api/admin/ops.ts
+++ b/frontend/src/api/admin/ops.ts
@@ -1,52 +1,18 @@
/**
* Admin Ops API endpoints (vNext)
- * - Error logs list/detail + retry (client/upstream)
+ * - Error logs list/detail
* - Dashboard overview (raw path)
*/
import { apiClient } from '../client'
import type { PaginatedResponse } from '@/types'
-export type OpsRetryMode = 'client' | 'upstream'
export type OpsQueryMode = 'auto' | 'raw' | 'preagg'
export interface OpsRequestOptions {
signal?: AbortSignal
}
-export interface OpsRetryRequest {
- mode: OpsRetryMode
- pinned_account_id?: number
- force?: boolean
-}
-
-export interface OpsRetryAttempt {
- id: number
- created_at: string
- requested_by_user_id: number
- source_error_id: number
- mode: string
- pinned_account_id?: number | null
- pinned_account_name?: string
-
- status: string
- started_at?: string | null
- finished_at?: string | null
- duration_ms?: number | null
-
- success?: boolean | null
- http_status_code?: number | null
- upstream_request_id?: string | null
- used_account_id?: number | null
- used_account_name?: string
- response_preview?: string | null
- response_truncated?: boolean | null
-
- result_request_id?: string | null
- result_error_id?: number | null
- error_message?: string | null
-}
-
export type OpsUpstreamErrorEvent = {
at_unix_ms?: number
platform?: string
@@ -54,33 +20,11 @@ export type OpsUpstreamErrorEvent = {
account_name?: string
upstream_status_code?: number
upstream_request_id?: string
- upstream_request_body?: string
kind?: string
message?: string
detail?: string
}
-export interface OpsRetryResult {
- attempt_id: number
- mode: OpsRetryMode
- status: 'running' | 'succeeded' | 'failed' | string
-
- pinned_account_id?: number | null
- used_account_id?: number | null
-
- http_status_code: number
- upstream_request_id: string
-
- response_preview: string
- response_truncated: boolean
-
- error_message: string
-
- started_at: string
- finished_at: string
- duration_ms: number
-}
-
export interface OpsDashboardOverview {
start_time: string
end_time: string
@@ -946,13 +890,9 @@ export interface OpsErrorLog {
platform: string
model: string
- is_retryable: boolean
- retry_count: number
-
resolved: boolean
resolved_at?: string | null
resolved_by_user_id?: number | null
- resolved_retry_id?: number | null
client_request_id: string
request_id: string
@@ -994,10 +934,6 @@ export interface OpsErrorDetail extends OpsErrorLog {
response_latency_ms?: number | null
time_to_first_token_ms?: number | null
- request_body: string
- request_body_truncated: boolean
- request_body_bytes?: number | null
-
is_business_limited: boolean
}
@@ -1156,16 +1092,6 @@ export async function getErrorLogDetail(id: number): Promise {
return data
}
-export async function retryErrorRequest(id: number, req: OpsRetryRequest): Promise {
- const { data } = await apiClient.post(`/admin/ops/errors/${id}/retry`, req)
- return data
-}
-
-export async function listRetryAttempts(errorId: number, limit = 50): Promise {
- const { data } = await apiClient.get(`/admin/ops/errors/${errorId}/retries`, { params: { limit } })
- return data
-}
-
export async function updateErrorResolved(errorId: number, resolved: boolean): Promise {
await apiClient.put(`/admin/ops/errors/${errorId}/resolve`, { resolved })
}
@@ -1191,21 +1117,6 @@ export async function getUpstreamErrorDetail(id: number): Promise {
- const { data } = await apiClient.post(`/admin/ops/request-errors/${id}/retry-client`, {})
- return data
-}
-
-export async function retryRequestErrorUpstreamEvent(id: number, idx: number): Promise {
- const { data } = await apiClient.post(`/admin/ops/request-errors/${id}/upstream-errors/${idx}/retry`, {})
- return data
-}
-
-export async function retryUpstreamError(id: number): Promise {
- const { data } = await apiClient.post(`/admin/ops/upstream-errors/${id}/retry`, {})
- return data
-}
-
export async function updateRequestErrorResolved(errorId: number, resolved: boolean): Promise {
await apiClient.put(`/admin/ops/request-errors/${errorId}/resolve`, { resolved })
}
@@ -1380,8 +1291,6 @@ export const opsAPI = {
// Legacy unified endpoints
listErrorLogs,
getErrorLogDetail,
- retryErrorRequest,
- listRetryAttempts,
updateErrorResolved,
// New split endpoints
@@ -1389,9 +1298,6 @@ export const opsAPI = {
listUpstreamErrors,
getRequestErrorDetail,
getUpstreamErrorDetail,
- retryRequestErrorClient,
- retryRequestErrorUpstreamEvent,
- retryUpstreamError,
updateRequestErrorResolved,
updateUpstreamErrorResolved,
listRequestErrorUpstreamErrors,
diff --git a/frontend/src/api/admin/redeem.ts b/frontend/src/api/admin/redeem.ts
index 57626b1e..398d68a4 100644
--- a/frontend/src/api/admin/redeem.ts
+++ b/frontend/src/api/admin/redeem.ts
@@ -60,6 +60,7 @@ export async function getById(id: number): Promise {
* @param value - Value of the code
* @param groupId - Group ID (required for subscription type)
* @param validityDays - Validity days (for subscription type)
+ * @param expiresInDays - Days before the code itself expires
* @returns Array of generated redeem codes
*/
export async function generate(
@@ -67,7 +68,8 @@ export async function generate(
type: RedeemCodeType,
value: number,
groupId?: number | null,
- validityDays?: number
+ validityDays?: number,
+ expiresInDays?: number | null
): Promise {
const payload: GenerateRedeemCodesRequest = {
count,
@@ -82,6 +84,9 @@ export async function generate(
payload.validity_days = validityDays
}
}
+ if (expiresInDays && expiresInDays > 0) {
+ payload.expires_in_days = expiresInDays
+ }
const { data } = await apiClient.post('/admin/redeem-codes/generate', payload)
return data
diff --git a/frontend/src/api/admin/riskControl.ts b/frontend/src/api/admin/riskControl.ts
index e63a53a2..4dad1f58 100644
--- a/frontend/src/api/admin/riskControl.ts
+++ b/frontend/src/api/admin/riskControl.ts
@@ -1,6 +1,7 @@
import { apiClient } from '../client'
export type ModerationMode = 'off' | 'observe' | 'pre_block'
+export type KeywordBlockingMode = 'keyword_only' | 'keyword_and_api' | 'api_only'
export interface ContentModerationConfig {
enabled: boolean
@@ -29,6 +30,8 @@ export interface ContentModerationConfig {
hit_retention_days: number
non_hit_retention_days: number
pre_hash_check_enabled: boolean
+ blocked_keywords: string[]
+ keyword_blocking_mode: KeywordBlockingMode
}
export type ContentModerationAPIKeyStatusValue = 'unknown' | 'ok' | 'error' | 'frozen'
@@ -100,6 +103,8 @@ export interface UpdateContentModerationConfig {
hit_retention_days?: number
non_hit_retention_days?: number
pre_hash_check_enabled?: boolean
+ blocked_keywords?: string[]
+ keyword_blocking_mode?: KeywordBlockingMode
}
export interface ContentModerationRuntimeStatus {
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index badd06e5..8550ea3f 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -22,7 +22,8 @@ export type AuthSourceType =
| "oidc"
| "wechat"
| "github"
- | "google";
+ | "google"
+ | "dingtalk";
export interface AuthSourceDefaultsValue {
balance: number;
@@ -64,6 +65,7 @@ const AUTH_SOURCE_TYPES: AuthSourceType[] = [
"wechat",
"github",
"google",
+ "dingtalk",
];
const AUTH_SOURCE_DEFAULT_BALANCE = 0;
const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5;
@@ -352,6 +354,11 @@ export interface SystemSettings {
auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
auth_source_default_wechat_grant_on_signup?: boolean;
auth_source_default_wechat_grant_on_first_bind?: boolean;
+ auth_source_default_dingtalk_balance?: number;
+ auth_source_default_dingtalk_concurrency?: number;
+ auth_source_default_dingtalk_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_dingtalk_grant_on_signup?: boolean;
+ auth_source_default_dingtalk_grant_on_first_bind?: boolean;
auth_source_default_github_balance?: number;
auth_source_default_github_concurrency?: number;
auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[];
@@ -396,6 +403,24 @@ export interface SystemSettings {
linuxdo_connect_client_secret_configured: boolean;
linuxdo_connect_redirect_url: string;
+ // DingTalk Connect OAuth settings
+ dingtalk_connect_enabled: boolean;
+ dingtalk_connect_client_id: string;
+ dingtalk_connect_client_secret_configured: boolean;
+ dingtalk_connect_redirect_url: string;
+ dingtalk_connect_corp_restriction_policy: string;
+ dingtalk_connect_internal_corp_id: string;
+ dingtalk_connect_bypass_registration: boolean;
+ dingtalk_connect_sync_corp_email: boolean;
+ dingtalk_connect_sync_display_name: boolean;
+ dingtalk_connect_sync_dept: boolean;
+ dingtalk_connect_sync_corp_email_attr_key: string;
+ dingtalk_connect_sync_display_name_attr_key: string;
+ dingtalk_connect_sync_dept_attr_key: string;
+ dingtalk_connect_sync_corp_email_attr_name: string;
+ dingtalk_connect_sync_display_name_attr_name: string;
+ dingtalk_connect_sync_dept_attr_name: string;
+
// WeChat Connect OAuth settings
wechat_connect_enabled: boolean;
wechat_connect_app_id: string;
@@ -504,6 +529,7 @@ export interface SystemSettings {
payment_cancel_rate_limit_window: number;
payment_cancel_rate_limit_unit: string;
payment_cancel_rate_limit_window_mode: string;
+ payment_alipay_force_qrcode?: boolean;
payment_visible_method_alipay_source?: string;
payment_visible_method_wxpay_source?: string;
payment_visible_method_alipay_enabled?: boolean;
@@ -572,6 +598,11 @@ export interface UpdateSettingsRequest {
auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[];
auth_source_default_wechat_grant_on_signup?: boolean;
auth_source_default_wechat_grant_on_first_bind?: boolean;
+ auth_source_default_dingtalk_balance?: number;
+ auth_source_default_dingtalk_concurrency?: number;
+ auth_source_default_dingtalk_subscriptions?: DefaultSubscriptionSetting[];
+ auth_source_default_dingtalk_grant_on_signup?: boolean;
+ auth_source_default_dingtalk_grant_on_first_bind?: boolean;
auth_source_default_github_balance?: number;
auth_source_default_github_concurrency?: number;
auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[];
@@ -610,6 +641,22 @@ export interface UpdateSettingsRequest {
linuxdo_connect_client_id?: string;
linuxdo_connect_client_secret?: string;
linuxdo_connect_redirect_url?: string;
+ dingtalk_connect_enabled?: boolean;
+ dingtalk_connect_client_id?: string;
+ dingtalk_connect_client_secret?: string;
+ dingtalk_connect_redirect_url?: string;
+ dingtalk_connect_corp_restriction_policy?: string;
+ dingtalk_connect_internal_corp_id?: string;
+ dingtalk_connect_bypass_registration?: boolean;
+ dingtalk_connect_sync_corp_email?: boolean;
+ dingtalk_connect_sync_display_name?: boolean;
+ dingtalk_connect_sync_dept?: boolean;
+ dingtalk_connect_sync_corp_email_attr_key?: string;
+ dingtalk_connect_sync_display_name_attr_key?: string;
+ dingtalk_connect_sync_dept_attr_key?: string;
+ dingtalk_connect_sync_corp_email_attr_name?: string;
+ dingtalk_connect_sync_display_name_attr_name?: string;
+ dingtalk_connect_sync_dept_attr_name?: string;
wechat_connect_enabled?: boolean;
wechat_connect_app_id?: string;
wechat_connect_app_secret?: string;
@@ -701,6 +748,7 @@ export interface UpdateSettingsRequest {
payment_cancel_rate_limit_window?: number;
payment_cancel_rate_limit_unit?: string;
payment_cancel_rate_limit_window_mode?: string;
+ payment_alipay_force_qrcode?: boolean;
payment_visible_method_alipay_source?: string;
payment_visible_method_wxpay_source?: string;
payment_visible_method_alipay_enabled?: boolean;
diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts
index bb990fc4..fd259230 100644
--- a/frontend/src/api/auth.ts
+++ b/frontend/src/api/auth.ts
@@ -592,7 +592,7 @@ export async function completeWeChatOAuthRegistration(
}
async function createPendingOAuthAccount(
- provider: 'linuxdo' | 'oidc' | 'wechat',
+ provider: 'linuxdo' | 'oidc' | 'wechat' | 'dingtalk',
invitationCode: string,
decision?: OAuthAdoptionDecision,
affiliateCode?: string
@@ -633,6 +633,14 @@ export async function createPendingWeChatOAuthAccount(
return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode)
}
+export async function createPendingDingTalkOAuthAccount(
+ invitationCode: string,
+ decision?: OAuthAdoptionDecision,
+ affiliateCode?: string
+): Promise {
+ return createPendingOAuthAccount('dingtalk', invitationCode, decision, affiliateCode)
+}
+
export async function completePendingOAuthBindLogin(
decision?: OAuthAdoptionDecision
): Promise {
@@ -683,7 +691,8 @@ export const authAPI = {
exchangePendingOAuthCompletion,
completeLinuxDoOAuthRegistration,
completeOIDCOAuthRegistration,
- completeWeChatOAuthRegistration
+ completeWeChatOAuthRegistration,
+ createPendingDingTalkOAuthAccount
}
export default authAPI
diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts
index 802c428f..7169b698 100644
--- a/frontend/src/api/usage.ts
+++ b/frontend/src/api/usage.ts
@@ -15,6 +15,16 @@ import type {
// ==================== Dashboard Types ====================
+export interface PlatformDashboardStats {
+ platform: string
+ total_requests: number
+ total_tokens: number
+ total_actual_cost: number
+ today_requests: number
+ today_tokens: number
+ today_actual_cost: number
+}
+
export interface UserDashboardStats {
total_api_keys: number
active_api_keys: number
@@ -37,6 +47,7 @@ export interface UserDashboardStats {
average_duration_ms: number
rpm: number // 近5分钟平均每分钟请求数
tpm: number // 近5分钟平均每分钟Token数
+ by_platform?: PlatformDashboardStats[]
}
export interface TrendParams {
diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue
index dd38a49f..8438c584 100644
--- a/frontend/src/components/account/AccountStatusIndicator.vue
+++ b/frontend/src/components/account/AccountStatusIndicator.vue
@@ -231,6 +231,8 @@ const formatScopeName = (scope: string): string => {
'gemini-2.5-flash-thinking': 'G25FT',
'gemini-2.5-pro': 'G25P',
'gemini-2.5-flash-image': 'G25I',
+ // Gemini 3.5 系列
+ 'gemini-3.5-flash': 'G35F',
// Gemini 3 系列
'gemini-3-flash': 'G3F',
'gemini-3.1-pro-high': 'G3PH',
diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue
index be9cba48..8f81789d 100644
--- a/frontend/src/components/account/AccountTestModal.vue
+++ b/frontend/src/components/account/AccountTestModal.vue
@@ -292,7 +292,7 @@ const openAITestModeOptions = computed(() => [
{ value: 'compact', label: t('admin.accounts.openai.testModeCompact') }
])
const previewImageUrl = ref('')
-const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
+const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-3.5-flash', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
if (!modelID.startsWith('gemini-') || !modelID.includes('-image')) return false
diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue
index 90a67922..2e160b62 100644
--- a/frontend/src/components/account/AccountUsageCell.vue
+++ b/frontend/src/components/account/AccountUsageCell.vue
@@ -126,6 +126,30 @@
:show-now-when-idle="true"
color="emerald"
/>
+
+
+
+
+
+ {{ t('admin.accounts.usageWindow.activeQuery') }}
+
+
@@ -1135,7 +1159,7 @@ const attachVisibilityObserver = () => {
const loadActiveUsage = async () => {
activeQueryLoading.value = true
try {
- usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active')
+ usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active', true)
} catch (e: any) {
console.error('Failed to load active usage:', e)
} finally {
diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue
index b55fe98a..b078a557 100644
--- a/frontend/src/components/account/CreateAccountModal.vue
+++ b/frontend/src/components/account/CreateAccountModal.vue
@@ -2760,6 +2760,28 @@
+
+
+
+
+
{{ t('admin.accounts.openai.responsesMode') }}
+
+ {{ t('admin.accounts.openai.responsesModeDesc') }}
+
+
+
+
+
+
+
+
@@ -3237,7 +3259,8 @@ import type {
CheckMixedChannelResponse,
CreateAccountRequest,
CodexSessionImportMessage,
- OpenAICompactMode
+ OpenAICompactMode,
+ OpenAIResponsesMode
} from '@/types'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
@@ -3395,6 +3418,7 @@ const interceptWarmupRequests = ref(false)
const autoPauseOnExpired = ref(true)
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref
('auto')
+const openAIResponsesMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -3458,6 +3482,11 @@ const openAICompactModeOptions = computed(() => [
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
])
+const openAIResponsesModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.responsesModeAuto') },
+ { value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
+ { value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
+])
function buildAntigravityExtra(): Record | undefined {
const extra: Record = {}
@@ -4190,6 +4219,7 @@ const resetForm = () => {
autoPauseOnExpired.value = true
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
+ openAIResponsesMode.value = 'auto'
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
codexCLIOnlyEnabled.value = false
@@ -4282,6 +4312,12 @@ const buildOpenAIExtra = (base?: Record): Record 0 ? extra : undefined
}
diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue
index 80f0b890..9e664764 100644
--- a/frontend/src/components/account/EditAccountModal.vue
+++ b/frontend/src/components/account/EditAccountModal.vue
@@ -139,10 +139,10 @@
-
+
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
- {{
+ {{
t('admin.accounts.supportsAllModels')
}}
@@ -454,10 +454,10 @@
-
+
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
- {{
+ {{
t('admin.accounts.supportsAllModels')
}}
@@ -666,10 +666,10 @@
-
+
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
- {{
+ {{
t('admin.accounts.supportsAllModels')
}}
@@ -891,7 +891,7 @@
{{ t('admin.accounts.selectedModels', { count: allowedModels.length }) }}
- {{ t('admin.accounts.supportsAllModels') }}
+ {{ t('admin.accounts.supportsAllModels') }}
@@ -987,6 +987,17 @@
{{ t('admin.accounts.mapRequestModels') }}
+
+
+ {{ isSyncingAntigravityUpstream ? t('admin.accounts.syncUpstreamModelsLoading') : t('admin.accounts.syncUpstreamModels') }}
+
+
+
+
+
+
+
+
{{ t('admin.accounts.openai.responsesMode') }}
+
+ {{ t('admin.accounts.openai.responsesModeDesc') }}
+
+
+
+
+
+
+
+ {{ t(openAIResponsesStatusKey) }}
+
+
+
('whitelist')
const antigravityWhitelistModels = ref
([])
const antigravityModelMappings = ref([])
+const isSyncingAntigravityUpstream = ref(false)
const tempUnschedEnabled = ref(false)
const tempUnschedRules = ref([])
const getModelMappingKey = createStableObjectKeyResolver('edit-model-mapping')
@@ -2332,6 +2370,7 @@ const customBaseUrl = ref('')
// OpenAI 自动透传开关(OAuth/API Key)
const openaiPassthroughEnabled = ref(false)
const openAICompactMode = ref('auto')
+const openAIResponsesMode = ref('auto')
const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF)
const codexCLIOnlyEnabled = ref(false)
@@ -2433,9 +2472,36 @@ const openAICompactModeOptions = computed(() => [
{ value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') },
{ value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') }
])
+const openAIResponsesModeOptions = computed(() => [
+ { value: 'auto', label: t('admin.accounts.openai.responsesModeAuto') },
+ { value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') },
+ { value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') }
+])
+const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => {
+ if (mode === 'force_responses' || mode === 'force_chat_completions') {
+ return mode
+ }
+ return 'auto'
+}
const isOpenAIModelRestrictionDisabled = computed(() =>
props.account?.platform === 'openai' && openaiPassthroughEnabled.value
)
+const openAIResponsesStatusKey = computed(() => {
+ if (openAIResponsesMode.value === 'force_responses') {
+ return 'admin.accounts.openai.responsesStatusForcedResponses'
+ }
+ if (openAIResponsesMode.value === 'force_chat_completions') {
+ return 'admin.accounts.openai.responsesStatusForcedChatCompletions'
+ }
+ const extra = props.account?.extra as Record | undefined
+ if (extra?.openai_responses_supported === true) {
+ return 'admin.accounts.openai.responsesStatusAutoSupported'
+ }
+ if (extra?.openai_responses_supported === false) {
+ return 'admin.accounts.openai.responsesStatusAutoUnsupported'
+ }
+ return 'admin.accounts.openai.responsesStatusAutoUnknown'
+})
const openAICompactStatusKey = computed(() => {
const extra = props.account?.extra as Record | undefined
if (!props.account || props.account.platform !== 'openai') return ''
@@ -2542,6 +2608,19 @@ const normalizePoolModeRetryCount = (value: number) => {
return normalized
}
+const loadModelRestrictionFromMapping = (rawMapping?: Record) => {
+ const parsed = splitModelMappingObject(rawMapping)
+ allowedModels.value = parsed.allowedModels
+ modelMappings.value = parsed.modelMappings
+ modelRestrictionMode.value =
+ parsed.modelMappings.length > 0 && parsed.allowedModels.length === 0
+ ? 'mapping'
+ : 'whitelist'
+}
+
+const buildModelRestrictionMapping = () =>
+ buildModelMappingObject('combined', allowedModels.value, modelMappings.value)
+
const syncFormFromAccount = (newAccount: Account | null) => {
if (!newAccount) {
return
@@ -2582,6 +2661,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load OpenAI passthrough toggle (OpenAI OAuth/API Key)
openaiPassthroughEnabled.value = false
openAICompactMode.value = 'auto'
+ openAIResponsesMode.value = 'auto'
openAICompactModelMappings.value = []
openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF
@@ -2592,6 +2672,9 @@ const syncFormFromAccount = (newAccount: Account | null) => {
if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) {
openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true
openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto'
+ if (newAccount.type === 'apikey') {
+ openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode)
+ }
const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean'
? extra.codex_image_generation_bridge
: extra?.codex_image_generation_bridge_enabled
@@ -2713,30 +2796,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl
// Load model mappings and detect mode
- const existingMappings = credentials.model_mapping as Record | undefined
- if (existingMappings && typeof existingMappings === 'object') {
- const entries = Object.entries(existingMappings)
-
- // Detect if this is whitelist mode (all from === to) or mapping mode
- const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
-
- if (isWhitelistMode) {
- // Whitelist mode: populate allowedModels
- modelRestrictionMode.value = 'whitelist'
- allowedModels.value = entries.map(([from]) => from)
- modelMappings.value = []
- } else {
- // Mapping mode: populate modelMappings
- modelRestrictionMode.value = 'mapping'
- modelMappings.value = entries.map(([from, to]) => ({ from, to }))
- allowedModels.value = []
- }
- } else {
- // No mappings: default to whitelist mode with empty selection (allow all)
- modelRestrictionMode.value = 'whitelist'
- modelMappings.value = []
- allowedModels.value = []
- }
+ loadModelRestrictionFromMapping(credentials.model_mapping as Record | undefined)
// Load pool mode
poolModeEnabled.value = credentials.pool_mode === true
@@ -2780,24 +2840,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
loadQuotaNotifyFromExtra(bedrockExtra)
// Load model mappings for bedrock
- const existingMappings = bedrockCreds.model_mapping as Record | undefined
- if (existingMappings && typeof existingMappings === 'object') {
- const entries = Object.entries(existingMappings)
- const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
- if (isWhitelistMode) {
- modelRestrictionMode.value = 'whitelist'
- allowedModels.value = entries.map(([from]) => from)
- modelMappings.value = []
- } else {
- modelRestrictionMode.value = 'mapping'
- modelMappings.value = entries.map(([from, to]) => ({ from, to }))
- allowedModels.value = []
- }
- } else {
- modelRestrictionMode.value = 'whitelist'
- modelMappings.value = []
- allowedModels.value = []
- }
+ loadModelRestrictionFromMapping(bedrockCreds.model_mapping as Record | undefined)
} else if (newAccount.type === 'upstream' && newAccount.credentials) {
const credentials = newAccount.credentials as Record
editBaseUrl.value = (credentials.base_url as string) || ''
@@ -2808,24 +2851,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1'
// Load model mappings for service_account
- const existingMappings = credentials.model_mapping as Record | undefined
- if (existingMappings && typeof existingMappings === 'object') {
- const entries = Object.entries(existingMappings)
- const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
- if (isWhitelistMode) {
- modelRestrictionMode.value = 'whitelist'
- allowedModels.value = entries.map(([from]) => from)
- modelMappings.value = []
- } else {
- modelRestrictionMode.value = 'mapping'
- modelMappings.value = entries.map(([from, to]) => ({ from, to }))
- allowedModels.value = []
- }
- } else {
- modelRestrictionMode.value = 'whitelist'
- modelMappings.value = []
- allowedModels.value = []
- }
+ loadModelRestrictionFromMapping(credentials.model_mapping as Record | undefined)
} else {
const platformDefaultUrl =
newAccount.platform === 'openai'
@@ -2838,24 +2864,7 @@ const syncFormFromAccount = (newAccount: Account | null) => {
// Load model mappings for OpenAI OAuth accounts
if (newAccount.platform === 'openai' && newAccount.credentials) {
const oauthCredentials = newAccount.credentials as Record
- const existingMappings = oauthCredentials.model_mapping as Record | undefined
- if (existingMappings && typeof existingMappings === 'object') {
- const entries = Object.entries(existingMappings)
- const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to)
- if (isWhitelistMode) {
- modelRestrictionMode.value = 'whitelist'
- allowedModels.value = entries.map(([from]) => from)
- modelMappings.value = []
- } else {
- modelRestrictionMode.value = 'mapping'
- modelMappings.value = entries.map(([from, to]) => ({ from, to }))
- allowedModels.value = []
- }
- } else {
- modelRestrictionMode.value = 'whitelist'
- modelMappings.value = []
- allowedModels.value = []
- }
+ loadModelRestrictionFromMapping(oauthCredentials.model_mapping as Record | undefined)
} else {
modelRestrictionMode.value = 'whitelist'
modelMappings.value = []
@@ -2935,6 +2944,40 @@ const addAntigravityPresetMapping = (from: string, to: string) => {
antigravityModelMappings.value.push({ from, to })
}
+const syncAntigravityUpstreamModels = async () => {
+ if (!props.account?.id || isSyncingAntigravityUpstream.value) return
+
+ isSyncingAntigravityUpstream.value = true
+ try {
+ const result = await adminAPI.accounts.syncUpstreamModels(props.account.id)
+ const upstreamModels = result.models.map((model) => model.trim()).filter(Boolean)
+ if (upstreamModels.length === 0) {
+ appStore.showInfo(t('admin.accounts.syncUpstreamModelsEmpty'))
+ return
+ }
+
+ let addedCount = 0
+ for (const model of upstreamModels) {
+ const exists = antigravityModelMappings.value.some((mapping) => mapping.from === model)
+ if (!exists) {
+ antigravityModelMappings.value.push({ from: model, to: model })
+ addedCount += 1
+ }
+ }
+
+ if (addedCount > 0) {
+ appStore.showSuccess(t('admin.accounts.syncUpstreamModelsSuccess', { count: addedCount, total: upstreamModels.length }))
+ } else {
+ appStore.showInfo(t('admin.accounts.syncUpstreamModelsNoChanges', { count: upstreamModels.length }))
+ }
+ } catch (error) {
+ const message = error instanceof Error ? error.message : t('admin.accounts.syncUpstreamModelsFailed')
+ appStore.showError(t('admin.accounts.syncUpstreamModelsError', { message }))
+ } finally {
+ isSyncingAntigravityUpstream.value = false
+ }
+}
+
// Error code toggle helper
const toggleErrorCode = (code: number) => {
const index = selectedErrorCodes.value.indexOf(code)
@@ -3343,20 +3386,22 @@ const handleSubmit = async () => {
}
// Handle API key
+ // 后端响应已脱敏:currentCredentials 不会再包含 api_key 原文。
+ // 用户填入新值则覆盖;留空时优先看 credentials_status.has_api_key;
+ // 若后端尚未升级(无 credentials_status),回退读旧结构 currentCredentials.api_key。
+ // 两者都无才报错。
+ const hasExistingApiKey =
+ props.account.credentials_status?.has_api_key ?? Boolean(currentCredentials.api_key)
if (editApiKey.value.trim()) {
- // User provided a new API key
newCredentials.api_key = editApiKey.value.trim()
- } else if (currentCredentials.api_key) {
- // Preserve existing api_key
- newCredentials.api_key = currentCredentials.api_key
- } else {
+ } else if (!hasExistingApiKey) {
appStore.showError(t('admin.accounts.apiKeyIsRequired'))
return
}
// Add model mapping if configured(OpenAI 开启自动透传时保留现有映射,不再编辑)
if (shouldApplyModelMapping) {
- const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ const modelMapping = buildModelRestrictionMapping()
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3434,7 +3479,15 @@ const handleSubmit = async () => {
return
}
- if (!currentCredentials.service_account_json && !currentCredentials.service_account) {
+ // SA JSON 已脱敏不再随 credentials 返回,存在性优先读 credentials_status。
+ // 若后端尚未升级(无 credentials_status),回退读旧结构 service_account_json / service_account。
+ const credentialsStatus = props.account.credentials_status
+ const hasExistingServiceAccountJson = credentialsStatus
+ ? Boolean(
+ credentialsStatus.has_service_account_json || credentialsStatus.has_service_account
+ )
+ : Boolean(currentCredentials.service_account_json || currentCredentials.service_account)
+ if (!hasExistingServiceAccountJson) {
appStore.showError(t('admin.accounts.vertexSaJsonRequired'))
return
}
@@ -3444,7 +3497,7 @@ const handleSubmit = async () => {
newCredentials.tier_id = 'vertex'
// Add model mapping if configured
- const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ const modelMapping = buildModelRestrictionMapping()
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3494,7 +3547,7 @@ const handleSubmit = async () => {
}
// Model mapping
- const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ const modelMapping = buildModelRestrictionMapping()
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3528,7 +3581,7 @@ const handleSubmit = async () => {
const shouldApplyModelMapping = !openaiPassthroughEnabled.value
if (shouldApplyModelMapping) {
- const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value)
+ const modelMapping = buildModelRestrictionMapping()
if (modelMapping) {
newCredentials.model_mapping = modelMapping
} else {
@@ -3721,6 +3774,13 @@ const handleSubmit = async () => {
} else {
newExtra.openai_compact_mode = openAICompactMode.value
}
+ if (props.account.type === 'apikey') {
+ if (openAIResponsesMode.value === 'auto') {
+ delete newExtra.openai_responses_mode
+ } else {
+ newExtra.openai_responses_mode = openAIResponsesMode.value
+ }
+ }
delete newExtra.codex_image_generation_bridge_enabled
if (codexImageGenerationBridgeMode.value === 'inherit') {
diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue
index ebce3740..9a0d6af8 100644
--- a/frontend/src/components/account/ModelWhitelistSelector.vue
+++ b/frontend/src/components/account/ModelWhitelistSelector.vue
@@ -85,6 +85,15 @@
>
{{ t('admin.accounts.fillRelatedModels') }}
+
+ {{ isSyncingUpstream ? t('admin.accounts.syncUpstreamModelsLoading') : t('admin.accounts.syncUpstreamModels') }}
+
()
const emit = defineEmits<{
@@ -145,6 +156,7 @@ const showDropdown = ref(false)
const searchQuery = ref('')
const customModel = ref('')
const isComposing = ref(false)
+const isSyncingUpstream = ref(false)
const normalizedPlatforms = computed(() => {
const rawPlatforms =
props.platforms && props.platforms.length > 0
@@ -162,6 +174,13 @@ const normalizedPlatforms = computed(() => {
)
})
+const upstreamSyncPlatforms = new Set(['anthropic', 'openai', 'gemini', 'antigravity'])
+const canSyncUpstream = computed(() => {
+ if (!props.accountId) return false
+ if (normalizedPlatforms.value.length === 0) return true
+ return normalizedPlatforms.value.some(platform => upstreamSyncPlatforms.has(platform.toLowerCase()))
+})
+
const availableOptions = computed(() => {
if (normalizedPlatforms.value.length === 0) {
return allModels
@@ -229,6 +248,41 @@ const fillRelated = () => {
emit('update:modelValue', newModels)
}
+const syncUpstreamModels = async () => {
+ if (!props.accountId || isSyncingUpstream.value) return
+
+ isSyncingUpstream.value = true
+ try {
+ const result = await accountsAPI.syncUpstreamModels(props.accountId)
+ const upstreamModels = result.models.map(model => model.trim()).filter(Boolean)
+ if (upstreamModels.length === 0) {
+ appStore.showInfo(t('admin.accounts.syncUpstreamModelsEmpty'))
+ return
+ }
+
+ const newModels = [...props.modelValue]
+ let addedCount = 0
+ for (const model of upstreamModels) {
+ if (!newModels.includes(model)) {
+ newModels.push(model)
+ addedCount += 1
+ }
+ }
+
+ emit('update:modelValue', newModels)
+ if (addedCount > 0) {
+ appStore.showSuccess(t('admin.accounts.syncUpstreamModelsSuccess', { count: addedCount, total: upstreamModels.length }))
+ } else {
+ appStore.showInfo(t('admin.accounts.syncUpstreamModelsNoChanges', { count: upstreamModels.length }))
+ }
+ } catch (error) {
+ const message = error instanceof Error ? error.message : t('admin.accounts.syncUpstreamModelsFailed')
+ appStore.showError(t('admin.accounts.syncUpstreamModelsError', { message }))
+ } finally {
+ isSyncingUpstream.value = false
+ }
+}
+
const clearAll = () => {
emit('update:modelValue', [])
}
diff --git a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
index 04486154..0b8e939c 100644
--- a/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
+++ b/frontend/src/components/account/__tests__/EditAccountModal.spec.ts
@@ -141,6 +141,32 @@ function buildAccount() {
} as any
}
+function buildVertexAccount() {
+ return {
+ id: 2,
+ name: 'Vertex SA',
+ notes: '',
+ platform: 'gemini',
+ type: 'service_account',
+ credentials: {
+ service_account_json: '{"type":"service_account","client_email":"sa@example.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\\nMIIE\\n-----END PRIVATE KEY-----\\n"}',
+ project_id: 'demo-project',
+ client_email: 'sa@example.iam.gserviceaccount.com',
+ location: 'us-central1',
+ tier_id: 'vertex'
+ },
+ extra: {},
+ proxy_id: null,
+ concurrency: 1,
+ priority: 1,
+ rate_multiplier: 1,
+ status: 'active',
+ group_ids: [],
+ expires_at: null,
+ auto_pause_on_expired: false
+ } as any
+}
+
function mountModal(account = buildAccount()) {
return mount(EditAccountModal, {
props: {
@@ -190,6 +216,31 @@ describe('EditAccountModal', () => {
})
})
+ it('preserves model mappings when editing the whitelist', async () => {
+ const account = buildAccount()
+ account.credentials.model_mapping = {
+ 'gpt-5.2': 'gpt-5.2',
+ 'gpt-latest': 'gpt-5.2'
+ }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ expect(wrapper.get('[data-testid="model-whitelist-value"]').text()).toBe('gpt-5.2')
+
+ await wrapper.get('[data-testid="rewrite-to-snapshot"]').trigger('click')
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.model_mapping).toEqual({
+ 'gpt-5.2-2025-12-11': 'gpt-5.2-2025-12-11',
+ 'gpt-latest': 'gpt-5.2'
+ })
+ })
+
it('submits OpenAI compact mode and compact-only model mapping', async () => {
const account = buildAccount()
account.extra = {
@@ -217,6 +268,48 @@ describe('EditAccountModal', () => {
})
})
+ it('submits OpenAI APIKey Responses support override mode', async () => {
+ const account = buildAccount()
+ account.extra = {
+ openai_responses_mode: 'force_chat_completions',
+ openai_responses_supported: false
+ }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('[data-testid="openai-responses-mode-select"]').setValue('force_responses')
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_mode).toBe('force_responses')
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(false)
+ })
+
+ it('clears OpenAI APIKey Responses override when set back to auto', async () => {
+ const account = buildAccount()
+ account.extra = {
+ openai_responses_mode: 'force_chat_completions',
+ openai_responses_supported: true
+ }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('[data-testid="openai-responses-mode-select"]').setValue('auto')
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra).not.toHaveProperty('openai_responses_mode')
+ expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.openai_responses_supported).toBe(true)
+ })
+
it('submits account-level Codex image generation bridge override', async () => {
const account = buildAccount()
account.extra = {
@@ -237,4 +330,122 @@ describe('EditAccountModal', () => {
expect(updateAccountMock.mock.calls[0]?.[1]?.extra?.codex_image_generation_bridge).toBe(true)
expect(updateAccountMock.mock.calls[0]?.[1]?.extra).not.toHaveProperty('codex_image_generation_bridge_enabled')
})
+
+ it('allows saving apikey account when backend redacted api_key but credentials_status reports it exists', async () => {
+ // 新前端 + 新后端:响应已脱敏,credentials 里没有 api_key,credentials_status.has_api_key=true
+ const account = buildAccount()
+ account.credentials = {
+ base_url: 'https://api.openai.com',
+ model_mapping: { 'gpt-5.2': 'gpt-5.2' }
+ }
+ account.credentials_status = { has_api_key: true }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ // 用户未输入新 key 时,payload 不应带 api_key,由后端合并保留旧值
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials).not.toHaveProperty('api_key')
+ })
+
+ it('allows saving apikey account against legacy backend without credentials_status', async () => {
+ // 新前端 + 旧后端:credentials_status 缺失,但 credentials.api_key 仍是明文,应允许保存
+ const account = buildAccount()
+ // 显式确保没有 credentials_status
+ expect(account.credentials_status).toBeUndefined()
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ // 旧后端响应未脱敏,原 api_key 会随 currentCredentials 一起传回去(旧行为,等价于无操作)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.api_key).toBe('sk-test')
+ })
+
+ it('blocks apikey save when neither credentials_status nor legacy api_key indicates existence', async () => {
+ const account = buildAccount()
+ account.credentials = {
+ base_url: 'https://api.openai.com'
+ }
+ // 既没有 credentials_status 也没有旧的 api_key
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).not.toHaveBeenCalled()
+ })
+
+ it('allows saving Vertex SA account when backend redacted service_account_json but credentials_status reports it exists', async () => {
+ // 新前端 + 新后端:响应已脱敏,credentials 里没有 service_account_json,credentials_status.has_service_account_json=true
+ const account = buildVertexAccount()
+ account.credentials = {
+ project_id: 'demo-project',
+ client_email: 'sa@example.iam.gserviceaccount.com',
+ location: 'us-central1',
+ tier_id: 'vertex'
+ }
+ account.credentials_status = { has_service_account_json: true }
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ expect(updateAccountMock.mock.calls[0]?.[1]?.credentials?.project_id).toBe('demo-project')
+ })
+
+ it('allows saving Vertex SA account against legacy backend without credentials_status', async () => {
+ // 新前端 + 旧后端:credentials_status 缺失,但 credentials.service_account_json 仍是明文,应允许保存
+ const account = buildVertexAccount()
+ expect(account.credentials_status).toBeUndefined()
+ expect(account.credentials.service_account_json).toBeTruthy()
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+ updateAccountMock.mockResolvedValue(account)
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).toHaveBeenCalledTimes(1)
+ })
+
+ it('blocks Vertex SA save when neither credentials_status nor legacy json indicates existence', async () => {
+ const account = buildVertexAccount()
+ account.credentials = {
+ project_id: 'demo-project',
+ client_email: 'sa@example.iam.gserviceaccount.com',
+ location: 'us-central1',
+ tier_id: 'vertex'
+ }
+ // 既没有 credentials_status 也没有旧的 service_account_json
+ updateAccountMock.mockReset()
+ checkMixedChannelRiskMock.mockReset()
+ checkMixedChannelRiskMock.mockResolvedValue({ has_risk: false })
+
+ const wrapper = mountModal(account)
+
+ await wrapper.get('form#edit-account-form').trigger('submit.prevent')
+
+ expect(updateAccountMock).not.toHaveBeenCalled()
+ })
})
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue
index ae0fd9a7..f4624f3f 100644
--- a/frontend/src/components/admin/account/AccountTestModal.vue
+++ b/frontend/src/components/admin/account/AccountTestModal.vue
@@ -275,7 +275,7 @@ const loadingModels = ref(false)
let abortController: AbortController | null = null
const generatedImages = ref([])
const previewImageUrl = ref('')
-const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
+const prioritizedGeminiModels = ['gemini-3.1-flash-image', 'gemini-2.5-flash-image', 'gemini-3.5-flash', 'gemini-2.5-flash', 'gemini-2.5-pro', 'gemini-3-flash-preview', 'gemini-3-pro-preview', 'gemini-2.0-flash']
const supportsGeminiImageTest = computed(() => {
const modelID = selectedModelId.value.toLowerCase()
if (!modelID.startsWith('gemini-') || !modelID.includes('-image')) return false
diff --git a/frontend/src/components/admin/channel/__tests__/types.spec.ts b/frontend/src/components/admin/channel/__tests__/types.spec.ts
new file mode 100644
index 00000000..9fd8e066
--- /dev/null
+++ b/frontend/src/components/admin/channel/__tests__/types.spec.ts
@@ -0,0 +1,79 @@
+import { describe, expect, it } from 'vitest'
+import { validateIntervals, type IntervalFormEntry } from '../types'
+
+function makeInterval(over: Partial): IntervalFormEntry {
+ return {
+ min_tokens: 0,
+ max_tokens: null,
+ tier_label: '',
+ input_price: null,
+ output_price: null,
+ cache_write_price: null,
+ cache_read_price: null,
+ per_request_price: null,
+ sort_order: 0,
+ ...over,
+ }
+}
+
+describe('validateIntervals', () => {
+ describe('token mode', () => {
+ it('rejects unbounded interval that is not last', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ min_tokens: 0, max_tokens: null, input_price: 1, output_price: 1 }),
+ makeInterval({ min_tokens: 200000, max_tokens: 500000, input_price: 2, output_price: 2 }),
+ ]
+ expect(validateIntervals(intervals, 'token')).toMatch(/无上限/)
+ })
+
+ it('accepts unbounded interval at the end', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ min_tokens: 0, max_tokens: 200000, input_price: 1, output_price: 1 }),
+ makeInterval({ min_tokens: 200000, max_tokens: null, input_price: 2, output_price: 2 }),
+ ]
+ expect(validateIntervals(intervals, 'token')).toBeNull()
+ })
+
+ it('rejects overlapping intervals', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ min_tokens: 0, max_tokens: 250000, input_price: 1, output_price: 1 }),
+ makeInterval({ min_tokens: 200000, max_tokens: 500000, input_price: 2, output_price: 2 }),
+ ]
+ expect(validateIntervals(intervals, 'token')).toMatch(/重叠/)
+ })
+
+ it('defaults mode to token when omitted', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ min_tokens: 0, max_tokens: null, input_price: 1, output_price: 1 }),
+ makeInterval({ min_tokens: 100, max_tokens: 200, input_price: 2, output_price: 2 }),
+ ]
+ expect(validateIntervals(intervals)).toMatch(/无上限/)
+ })
+ })
+
+ describe('image / per_request mode', () => {
+ it('allows multiple unbounded tiers identified by label', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ tier_label: '1K', per_request_price: 0.04 }),
+ makeInterval({ tier_label: '2K', per_request_price: 0.06 }),
+ makeInterval({ tier_label: '4K', per_request_price: 0.08 }),
+ ]
+ expect(validateIntervals(intervals, 'image')).toBeNull()
+ expect(validateIntervals(intervals, 'per_request')).toBeNull()
+ })
+
+ it('still rejects negative prices', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ tier_label: '1K', per_request_price: -1 }),
+ ]
+ expect(validateIntervals(intervals, 'image')).toMatch(/不能为负数/)
+ })
+
+ it('still rejects max <= min on a single tier', () => {
+ const intervals: IntervalFormEntry[] = [
+ makeInterval({ tier_label: '1K', min_tokens: 100, max_tokens: 50, per_request_price: 0.04 }),
+ ]
+ expect(validateIntervals(intervals, 'image')).toMatch(/必须大于/)
+ })
+ })
+})
diff --git a/frontend/src/components/admin/channel/types.ts b/frontend/src/components/admin/channel/types.ts
index 076bd589..960fe317 100644
--- a/frontend/src/components/admin/channel/types.ts
+++ b/frontend/src/components/admin/channel/types.ts
@@ -115,8 +115,17 @@ export function findModelConflict(models: string[]): [string, string] | null {
// ── 区间校验 ──────────────────────────────────────────────
-/** 校验区间列表的合法性,返回错误消息;通过则返回 null */
-export function validateIntervals(intervals: IntervalFormEntry[]): string | null {
+/** 校验区间列表的合法性,返回错误消息;通过则返回 null
+ *
+ * mode 决定区间语义:
+ * - token:区间是上下文 token 数分段 (min, max],不能重叠,无上限段必须放最后
+ * - per_request / image:区间是按 tier_label 分层(1K/2K/4K 等),后端按 label
+ * 匹配,不依赖 min/max,因此跳过重叠 / last-unlimited 校验
+ */
+export function validateIntervals(
+ intervals: IntervalFormEntry[],
+ mode: BillingMode = 'token',
+): string | null {
if (!intervals || intervals.length === 0) return null
// 按 min_tokens 排序(不修改原数组)
@@ -126,6 +135,9 @@ export function validateIntervals(intervals: IntervalFormEntry[]): string | null
const err = validateSingleInterval(sorted[i], i)
if (err) return err
}
+
+ // per_request / image 模式按 tier_label 匹配,不做 token 区间重叠校验
+ if (mode !== 'token') return null
return checkIntervalOverlap(sorted)
}
diff --git a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
index 0d6b4ace..404b6916 100644
--- a/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
+++ b/frontend/src/components/admin/monitor/MonitorAdvancedRequestConfig.vue
@@ -106,9 +106,15 @@
diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue
index 629e6aa2..65ac1548 100644
--- a/frontend/src/components/admin/usage/UsageTable.vue
+++ b/frontend/src/components/admin/usage/UsageTable.vue
@@ -86,19 +86,19 @@
-
- {{ getBillingModeLabel(row.billing_mode, t) }}
+
+ {{ getBillingModeLabel(getDisplayBillingMode(row), t) }}
-
+
{{ row.image_count }}{{ t('usage.imageUnit') }}
-
({{ row.image_size || '2K' }})
+
({{ formatImageBillingSize(row, t) }})
@@ -280,21 +280,30 @@
${{ tooltipData.output_cost.toFixed(6) }}
-
-
- {{ t('usage.inputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
- {{ t('usage.outputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
-
-
+
{{ t('usage.imageCount') }}
- {{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})
+ {{ tooltipData.image_count }}{{ t('usage.imageUnit') }}
+
+
+ {{ t('usage.imageBillingSize') }}
+ {{ formatImageBillingSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageSizeSource') }}
+ {{ formatImageSizeSource(tooltipData, t) }}
+
+
+ {{ t('usage.imageInputSize') }}
+ {{ formatImageInputSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageOutputSize') }}
+ {{ formatImageOutputSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageSizeBreakdown') }}
+ {{ formatImageSizeBreakdown(tooltipData) }}
{{ t('usage.imageUnitPrice') }}
@@ -305,6 +314,16 @@
${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
+
+
+ {{ t('usage.inputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+ {{ t('usage.outputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
{{ t('usage.unitPrice') }}
${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
@@ -366,6 +385,13 @@ import { formatTokenPricePerMillion } from '@/utils/usagePricing'
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
import { resolveUsageRequestType } from '@/utils/usageRequestType'
import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILLING_MODE_IMAGE } from '@/utils/billingMode'
+import {
+ formatImageBillingSize,
+ formatImageInputSize,
+ formatImageOutputSize,
+ formatImageSizeBreakdown,
+ formatImageSizeSource,
+} from '@/utils/imageUsage'
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
@@ -381,6 +407,17 @@ function imageUnitPrice(row: AdminUsageLog | null): number {
return Number.isFinite(price) ? price : 0
}
+function isImageUsage(row: Pick
| null | undefined): boolean {
+ return (row?.image_count ?? 0) > 0
+}
+
+function getDisplayBillingMode(row: Pick | null | undefined): string | null | undefined {
+ if (isImageUsage(row)) {
+ return BILLING_MODE_IMAGE
+ }
+ return row?.billing_mode
+}
+
import DataTable from '@/components/common/DataTable.vue'
import EmptyState from '@/components/common/EmptyState.vue'
import Icon from '@/components/icons/Icon.vue'
diff --git a/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts b/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts
index 9309c88b..ece0dbda 100644
--- a/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts
+++ b/frontend/src/components/admin/usage/__tests__/UsageTable.spec.ts
@@ -22,6 +22,26 @@ const messages: Record = {
'usage.original': 'Original',
'usage.userBilled': 'User billed',
'usage.accountBilled': 'Account billed',
+ 'usage.imageUnit': ' images',
+ 'usage.imageCount': 'Image count',
+ 'usage.imageBillingSize': 'Billing size',
+ 'usage.imageInputSize': 'Input size',
+ 'usage.imageOutputSize': 'Output size',
+ 'usage.imageSizeSource': 'Size source',
+ 'usage.imageSizeBreakdown': 'Size breakdown',
+ 'usage.imageSizeSourceOutput': 'Upstream output',
+ 'usage.imageSizeSourceInput': 'Request input',
+ 'usage.imageSizeSourceDefault': 'Default billing tier',
+ 'usage.imageSizeSourceLegacy': 'Legacy record',
+ 'usage.imageSizeSourceMissing': 'Not recorded',
+ 'usage.imageSizeNotRecorded': 'not recorded',
+ 'usage.imageSizeLegacyUnstandardized': 'legacy unstandardized',
+ 'usage.imageSizeUnknown': 'unknown',
+ 'usage.imageUnitPrice': 'Per-image price',
+ 'usage.imageTotalPrice': 'Image total price',
+ 'admin.usage.billingModeToken': 'Token',
+ 'admin.usage.billingModePerRequest': 'Per request',
+ 'admin.usage.billingModeImage': 'Image',
}
vi.mock('vue-i18n', async () => {
@@ -40,12 +60,42 @@ const DataTableStub = {
`,
}
+const baseImageRow = {
+ request_id: 'req-admin-image',
+ model: 'gpt-image-2',
+ actual_cost: 0.4,
+ total_cost: 0.4,
+ account_rate_multiplier: 1,
+ rate_multiplier: 1,
+ service_tier: null,
+ input_cost: 0,
+ output_cost: 0,
+ cache_creation_cost: 0,
+ cache_read_cost: 0,
+ input_tokens: 0,
+ output_tokens: 0,
+ cache_creation_tokens: 0,
+ cache_read_tokens: 0,
+ cache_creation_5m_tokens: 0,
+ cache_creation_1h_tokens: 0,
+ cache_ttl_overridden: false,
+ billing_mode: 'image',
+ image_count: 2,
+ image_size: '2K',
+ image_input_size: null,
+ image_output_size: null,
+ image_size_source: null,
+ image_size_breakdown: null,
+}
+
describe('admin UsageTable tooltip', () => {
beforeEach(() => {
vi.spyOn(HTMLElement.prototype, 'getBoundingClientRect').mockReturnValue({
@@ -93,7 +143,8 @@ describe('admin UsageTable tooltip', () => {
},
})
- await wrapper.find('.group.relative').trigger('mouseenter')
+ const tooltipTriggers = wrapper.findAll('.group.relative')
+ await tooltipTriggers[tooltipTriggers.length - 1].trigger('mouseenter')
await nextTick()
const text = wrapper.text()
@@ -147,4 +198,126 @@ describe('admin UsageTable tooltip', () => {
expect(text).toContain('claude-sonnet-4')
expect(text).toContain('claude-sonnet-4-20250514')
})
+
+ it.each([
+ {
+ name: 'defaulted row',
+ row: {
+ ...baseImageRow,
+ request_id: 'req-admin-default-image',
+ image_size: '2K',
+ image_input_size: 'auto',
+ image_output_size: null,
+ image_size_source: 'default',
+ },
+ expected: ['2K', 'Default billing tier', 'auto', 'unknown'],
+ },
+ {
+ name: 'output-sourced row',
+ row: {
+ ...baseImageRow,
+ request_id: 'req-admin-output-image',
+ image_size: '4K',
+ image_input_size: '1024x1024',
+ image_output_size: '3840x2160',
+ image_size_source: 'output',
+ image_size_breakdown: { '4K': 1 },
+ },
+ expected: ['4K', 'Upstream output', '1024x1024', '3840x2160', '4K x 1'],
+ },
+ {
+ name: 'input-sourced row',
+ row: {
+ ...baseImageRow,
+ request_id: 'req-admin-input-image',
+ image_size: '1K',
+ image_input_size: '1024x1024',
+ image_output_size: null,
+ image_size_source: 'input',
+ },
+ expected: ['1K', 'Request input', '1024x1024', 'unknown'],
+ },
+ {
+ name: 'legacy unstandardized row',
+ row: {
+ ...baseImageRow,
+ request_id: 'req-admin-legacy-unstandardized-image',
+ image_size: '512x512',
+ image_input_size: null,
+ image_output_size: null,
+ image_size_source: null,
+ },
+ expected: ['legacy unstandardized: 512x512', 'Legacy record', 'unknown'],
+ },
+ ])('shows image usage metadata for $name', async ({ row, expected }) => {
+ const wrapper = mount(UsageTable, {
+ props: {
+ data: [row],
+ loading: false,
+ columns: [],
+ },
+ global: {
+ stubs: {
+ DataTable: DataTableStub,
+ EmptyState: true,
+ Icon: true,
+ Teleport: true,
+ },
+ },
+ })
+
+ await wrapper.find('.group.relative').trigger('mouseenter')
+ await nextTick()
+
+ const text = wrapper.text()
+ expect(text).toContain('Image count')
+ expect(text).toContain('Billing size')
+ expect(text).toContain('Size source')
+ expect(text).toContain('Input size')
+ expect(text).toContain('Output size')
+ expect(text).toContain('Per-image price')
+ expect(text).toContain('Image total price')
+ for (const value of expected) {
+ expect(text).toContain(value)
+ }
+ })
+
+ it('displays historical image rows with missing billing_mode as image usage without a 2K fallback', async () => {
+ const wrapper = mount(UsageTable, {
+ props: {
+ data: [
+ {
+ ...baseImageRow,
+ request_id: 'req-admin-legacy-missing-image',
+ billing_mode: null,
+ image_size: null,
+ image_input_size: null,
+ image_output_size: null,
+ image_size_source: null,
+ image_size_breakdown: null,
+ },
+ ],
+ loading: false,
+ columns: [],
+ },
+ global: {
+ stubs: {
+ DataTable: DataTableStub,
+ EmptyState: true,
+ Icon: true,
+ Teleport: true,
+ },
+ },
+ })
+
+ await wrapper.find('.group.relative').trigger('mouseenter')
+ await nextTick()
+
+ const text = wrapper.text()
+ expect(text).toContain('Image')
+ expect(text).toContain('Image count')
+ expect(text).toContain('Per-image price')
+ expect(text).toContain('not recorded')
+ expect(text).not.toContain('(2K)')
+ })
})
diff --git a/frontend/src/components/auth/DingTalkOAuthSection.vue b/frontend/src/components/auth/DingTalkOAuthSection.vue
new file mode 100644
index 00000000..9003225d
--- /dev/null
+++ b/frontend/src/components/auth/DingTalkOAuthSection.vue
@@ -0,0 +1,61 @@
+
+
+
+
+
+ 钉
+
+ {{ t('auth.dingtalk.signIn') }}
+
+
+
+
+
+ {{ t('auth.oauthOrContinue') }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/auth/TotpLoginModal.vue b/frontend/src/components/auth/TotpLoginModal.vue
index 0ae2f482..5f68b9a7 100644
--- a/frontend/src/components/auth/TotpLoginModal.vue
+++ b/frontend/src/components/auth/TotpLoginModal.vue
@@ -24,6 +24,18 @@
+
+
(['', '', '', '', '', ''])
const inputRefs = ref<(HTMLInputElement | null)[]>([])
+const hiddenOtpInputRef = ref
(null)
// Watch for code changes and auto-submit when 6 digits are entered
watch(
@@ -104,6 +118,10 @@ defineExpose({
inputRefs.value.forEach(input => {
if (input) input.value = ''
})
+ // Clear hidden autofill input
+ if (hiddenOtpInputRef.value) {
+ hiddenOtpInputRef.value.value = ''
+ }
nextTick(() => {
inputRefs.value[0]?.focus()
})
@@ -126,6 +144,26 @@ const handleCodeInput = (event: Event, index: number) => {
}
}
+// Handle autofill from password managers via the hidden autocomplete="one-time-code" input
+const handleHiddenOtpInput = (event: Event) => {
+ const input = event.target as HTMLInputElement
+ const digits = input.value.replace(/[^0-9]/g, '').slice(0, 6).split('')
+
+ digits.forEach((digit, i) => {
+ code.value[i] = digit
+ if (inputRefs.value[i]) {
+ inputRefs.value[i]!.value = digit
+ }
+ })
+
+ for (let i = digits.length; i < 6; i++) {
+ code.value[i] = ''
+ if (inputRefs.value[i]) {
+ inputRefs.value[i]!.value = ''
+ }
+ }
+}
+
const handleKeydown = (event: KeyboardEvent, index: number) => {
if (event.key === 'Backspace') {
const input = event.target as HTMLInputElement
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 99478562..e3729cc8 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -785,6 +785,17 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
},
+ 'gemini-3.5-flash': {
+ name: 'Gemini 3.5 Flash',
+ limit: {
+ context: 1048576,
+ output: 65536
+ },
+ modalities: {
+ input: ['text', 'image', 'pdf'],
+ output: ['text']
+ }
+ },
'gemini-3-flash-preview': {
name: 'Gemini 3 Flash Preview',
limit: {
diff --git a/frontend/src/components/payment/PaymentProviderDialog.vue b/frontend/src/components/payment/PaymentProviderDialog.vue
index 86304cf6..b6085ed0 100644
--- a/frontend/src/components/payment/PaymentProviderDialog.vue
+++ b/frontend/src/components/payment/PaymentProviderDialog.vue
@@ -34,7 +34,7 @@
-
+
{{ t('admin.settings.payment.paymentMode') }}
const callbackPaths = computed(() => PROVIDER_CALLBACK_PATHS[form.provider_key] || null)
+const supportsPaymentMode = computed(() => providerSupportsPaymentMode(form.provider_key))
+
const paymentModeOptions = computed(() => {
+ if (form.provider_key === 'alipay') {
+ // For Alipay official: "" = default (precreate → page.pay fallback);
+ // "redirect" = always open the Alipay checkout page in a new tab.
+ return [
+ { value: '', label: t('admin.settings.payment.modeQRCode') },
+ { value: PAYMENT_MODE_REDIRECT, label: t('admin.settings.payment.modeRedirect') },
+ ]
+ }
return [
{ value: PAYMENT_MODE_QRCODE, label: t('admin.settings.payment.modeQRCode') },
{ value: PAYMENT_MODE_POPUP, label: t('admin.settings.payment.modePopup') },
@@ -476,6 +512,7 @@ function toggleType(type: string) {
function onKeyChange() {
form.supported_types = [...(PROVIDER_SUPPORTED_TYPES[form.provider_key] || [])]
+ form.payment_mode = defaultPaymentMode(form.provider_key)
clearConfig()
applyDefaults()
}
@@ -591,7 +628,7 @@ function handleSave() {
name: form.name,
supported_types: form.supported_types,
enabled: form.enabled,
- payment_mode: form.provider_key === 'easypay' ? form.payment_mode : '',
+ payment_mode: supportsPaymentMode.value ? form.payment_mode : '',
refund_enabled: form.refund_enabled,
allow_user_refund: form.refund_enabled ? form.allow_user_refund : false,
config: filteredConfig,
@@ -611,7 +648,7 @@ function reset(defaultKey: string) {
form.provider_key = defaultKey
form.supported_types = [...(PROVIDER_SUPPORTED_TYPES[defaultKey] || [])]
form.enabled = true
- form.payment_mode = defaultKey === 'easypay' ? PAYMENT_MODE_QRCODE : ''
+ form.payment_mode = defaultPaymentMode(defaultKey)
form.refund_enabled = false
form.allow_user_refund = false
clearConfig()
@@ -623,7 +660,12 @@ function loadProvider(provider: ProviderInstance) {
form.provider_key = provider.provider_key
form.supported_types = provider.supported_types
form.enabled = provider.enabled
- form.payment_mode = provider.payment_mode || (provider.provider_key === 'easypay' ? PAYMENT_MODE_QRCODE : '')
+ // Coerce to a valid value for this provider. Guards against stale data
+ // (e.g. "popup" written by an older client) showing up as an unselected
+ // button in the dialog.
+ form.payment_mode = isValidPaymentMode(provider.provider_key, provider.payment_mode || '')
+ ? (provider.payment_mode || '')
+ : defaultPaymentMode(provider.provider_key)
form.refund_enabled = provider.refund_enabled
form.allow_user_refund = provider.allow_user_refund
clearConfig()
diff --git a/frontend/src/components/payment/PaymentQRDialog.vue b/frontend/src/components/payment/PaymentQRDialog.vue
index 09d273cc..34cc5203 100644
--- a/frontend/src/components/payment/PaymentQRDialog.vue
+++ b/frontend/src/components/payment/PaymentQRDialog.vue
@@ -114,6 +114,11 @@ const paidOrder = ref(null)
let pollTimer: ReturnType | null = null
let countdownTimer: ReturnType | null = null
+let verifyAttempts = 0
+let lastVerifyAt = 0
+
+const VERIFY_RETRY_INTERVAL_MS = 15000
+const VERIFY_RETRY_MAX_ATTEMPTS = 6
const isAlipay = computed(() => props.paymentType.includes('alipay'))
const isWxpay = computed(() => props.paymentType.includes('wxpay'))
@@ -186,8 +191,9 @@ async function renderQR() {
async function pollStatus() {
if (!props.orderId) return
- const order = await paymentStore.pollOrderStatus(props.orderId)
+ let order = await paymentStore.pollOrderStatus(props.orderId)
if (!order) return
+ order = await tryRecoverPendingOrder(order)
if (order.status === 'COMPLETED' || order.status === 'PAID') {
cleanup()
paidOrder.value = order
@@ -199,6 +205,27 @@ async function pollStatus() {
}
}
+async function tryRecoverPendingOrder(order: PaymentOrder): Promise {
+ if (!isWxpay.value) return order
+ const outTradeNo = String(order.out_trade_no || '').trim()
+ if (!outTradeNo) return order
+ const normalizedStatus = String(order.status || '').trim().toUpperCase()
+ if (normalizedStatus !== 'PENDING') return order
+ const now = Date.now()
+ if (verifyAttempts >= VERIFY_RETRY_MAX_ATTEMPTS || now - lastVerifyAt < VERIFY_RETRY_INTERVAL_MS) {
+ return order
+ }
+
+ lastVerifyAt = now
+ verifyAttempts += 1
+ try {
+ const result = await paymentAPI.verifyOrder(outTradeNo)
+ return result.data ?? order
+ } catch {
+ return order
+ }
+}
+
function startCountdown(seconds: number) {
remainingSeconds.value = Math.max(0, seconds)
if (remainingSeconds.value <= 0) {
@@ -250,6 +277,8 @@ function init() {
expired.value = false
cancelling.value = false
qrUrl.value = props.qrCode
+ verifyAttempts = 0
+ lastVerifyAt = 0
let seconds = 30 * 60
if (props.expiresAt) {
diff --git a/frontend/src/components/payment/PaymentStatusPanel.vue b/frontend/src/components/payment/PaymentStatusPanel.vue
index 2c8b0a93..2a1349af 100644
--- a/frontend/src/components/payment/PaymentStatusPanel.vue
+++ b/frontend/src/components/payment/PaymentStatusPanel.vue
@@ -175,6 +175,11 @@ const outcome = ref(null)
let pollTimer: ReturnType | null = null
let countdownTimer: ReturnType | null = null
+let verifyAttempts = 0
+let lastVerifyAt = 0
+
+const VERIFY_RETRY_INTERVAL_MS = 15000
+const VERIFY_RETRY_MAX_ATTEMPTS = 6
const isAlipay = computed(() => props.paymentType.includes('alipay'))
const isWxpay = computed(() => props.paymentType.includes('wxpay'))
@@ -241,10 +246,32 @@ async function renderQR() {
})
}
+async function tryRecoverPendingOrder(order: PaymentOrder): Promise {
+ if (!isWxpay.value) return order
+ const outTradeNo = String(order.out_trade_no || '').trim()
+ if (!outTradeNo) return order
+ const normalizedStatus = String(order.status || '').trim().toUpperCase()
+ if (normalizedStatus !== 'PENDING') return order
+ const now = Date.now()
+ if (verifyAttempts >= VERIFY_RETRY_MAX_ATTEMPTS || now - lastVerifyAt < VERIFY_RETRY_INTERVAL_MS) {
+ return order
+ }
+
+ lastVerifyAt = now
+ verifyAttempts += 1
+ try {
+ const result = await paymentAPI.verifyOrder(outTradeNo)
+ return result.data ?? order
+ } catch {
+ return order
+ }
+}
+
async function pollStatus() {
if (!props.orderId || outcome.value) return
- const order = await paymentStore.pollOrderStatus(props.orderId)
+ let order = await paymentStore.pollOrderStatus(props.orderId)
if (!order) return
+ order = await tryRecoverPendingOrder(order)
if (isSuccessStatus(order.status)) {
cleanup()
paidOrder.value = order
@@ -291,6 +318,8 @@ function cleanup() {
// Initialize on mount
qrUrl.value = props.qrCode
+verifyAttempts = 0
+lastVerifyAt = 0
let seconds = 30 * 60
if (props.expiresAt) {
seconds = Math.floor((new Date(props.expiresAt).getTime() - Date.now()) / 1000)
diff --git a/frontend/src/components/payment/ProviderCard.vue b/frontend/src/components/payment/ProviderCard.vue
index 9a73c027..e64d5d5e 100644
--- a/frontend/src/components/payment/ProviderCard.vue
+++ b/frontend/src/components/payment/ProviderCard.vue
@@ -69,7 +69,7 @@ import Icon from '@/components/icons/Icon.vue'
import ToggleSwitch from './ToggleSwitch.vue'
import type { ProviderInstance } from '@/types/payment'
import type { TypeOption } from './providerConfig'
-import { PAYMENT_MODE_QRCODE, PAYMENT_MODE_POPUP } from './providerConfig'
+import { PAYMENT_MODE_QRCODE, PAYMENT_MODE_POPUP, PAYMENT_MODE_REDIRECT } from './providerConfig'
const PROVIDER_KEY_LABELS: Record = {
easypay: 'admin.settings.payment.providerEasypay',
@@ -99,6 +99,7 @@ const keyLabel = computed(() => t(PROVIDER_KEY_LABELS[props.provider.provider_ke
const modeLabel = computed(() => {
if (props.provider.payment_mode === PAYMENT_MODE_QRCODE) return t('admin.settings.payment.modeQRCode')
if (props.provider.payment_mode === PAYMENT_MODE_POPUP) return t('admin.settings.payment.modePopup')
+ if (props.provider.payment_mode === PAYMENT_MODE_REDIRECT) return t('admin.settings.payment.modeRedirect')
return ''
})
diff --git a/frontend/src/components/payment/SubscriptionPlanCard.vue b/frontend/src/components/payment/SubscriptionPlanCard.vue
index fbaa2744..687f119e 100644
--- a/frontend/src/components/payment/SubscriptionPlanCard.vue
+++ b/frontend/src/components/payment/SubscriptionPlanCard.vue
@@ -147,6 +147,7 @@ const MODEL_SCOPE_LABELS: Record = {
}
const modelScopeLabels = computed(() => {
+ if (platform.value !== 'antigravity') return []
const scopes = props.plan.supported_model_scopes
if (!scopes || scopes.length === 0) return []
return scopes.map(s => MODEL_SCOPE_LABELS[s] || s)
diff --git a/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
index ea2b6377..7e392478 100644
--- a/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
+++ b/frontend/src/components/payment/__tests__/PaymentStatusPanel.spec.ts
@@ -3,6 +3,7 @@ import { flushPromises, mount } from '@vue/test-utils'
const pollOrderStatus = vi.hoisted(() => vi.fn())
const cancelOrder = vi.hoisted(() => vi.fn())
+const verifyOrder = vi.hoisted(() => vi.fn())
const showError = vi.hoisted(() => vi.fn())
const toCanvas = vi.hoisted(() => vi.fn())
@@ -31,6 +32,7 @@ vi.mock('@/stores', () => ({
vi.mock('@/api/payment', () => ({
paymentAPI: {
cancelOrder,
+ verifyOrder,
},
}))
@@ -62,6 +64,7 @@ describe('PaymentStatusPanel', () => {
vi.useFakeTimers()
pollOrderStatus.mockReset()
cancelOrder.mockReset()
+ verifyOrder.mockReset()
showError.mockReset()
toCanvas.mockReset().mockResolvedValue(undefined)
})
@@ -128,4 +131,35 @@ describe('PaymentStatusPanel', () => {
openSpy.mockRestore()
})
+
+ it('actively verifies a stuck pending order and settles it when upstream confirms payment', async () => {
+ pollOrderStatus.mockResolvedValue(orderFactory('PENDING'))
+ verifyOrder.mockResolvedValue({
+ data: orderFactory('COMPLETED'),
+ })
+
+ const wrapper = mount(PaymentStatusPanel, {
+ props: {
+ orderId: 42,
+ qrCode: 'https://pay.example.com/qr/42',
+ expiresAt: '2099-01-01T12:30:00Z',
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ },
+ global: {
+ stubs: {
+ Icon: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ await vi.advanceTimersByTimeAsync(3000)
+ await flushPromises()
+
+ expect(pollOrderStatus).toHaveBeenCalledWith(42)
+ expect(verifyOrder).toHaveBeenCalledWith('sub2_20260420abcd1234')
+ expect(wrapper.text()).toContain('payment.result.success')
+ expect(wrapper.emitted('success')).toHaveLength(1)
+ })
})
diff --git a/frontend/src/components/payment/__tests__/SubscriptionPlanCard.spec.ts b/frontend/src/components/payment/__tests__/SubscriptionPlanCard.spec.ts
new file mode 100644
index 00000000..ebe695ee
--- /dev/null
+++ b/frontend/src/components/payment/__tests__/SubscriptionPlanCard.spec.ts
@@ -0,0 +1,64 @@
+import { mount } from "@vue/test-utils";
+import { describe, expect, it } from "vitest";
+import { createI18n } from "vue-i18n";
+import SubscriptionPlanCard from "../SubscriptionPlanCard.vue";
+
+const i18n = createI18n({
+ legacy: false,
+ locale: "en",
+ fallbackWarn: false,
+ missingWarn: false,
+ messages: {
+ en: {
+ payment: {
+ days: "days",
+ models: "Models",
+ planCard: {
+ quota: "Quota",
+ rate: "Rate",
+ unlimited: "Unlimited",
+ },
+ subscribeNow: "Subscribe now",
+ },
+ },
+ },
+});
+
+const mountPlanCard = (groupPlatform: string) =>
+ mount(SubscriptionPlanCard, {
+ props: {
+ plan: {
+ id: 1,
+ group_id: 10,
+ group_platform: groupPlatform,
+ name: "Pro",
+ price: 10,
+ amount: 1000,
+ features: [],
+ rate_multiplier: 1,
+ validity_days: 30,
+ validity_unit: "day",
+ supported_model_scopes: ["claude", "gemini_text", "gemini_image"],
+ is_active: true,
+ },
+ },
+ global: { plugins: [i18n] },
+ });
+
+describe("SubscriptionPlanCard", () => {
+ it("does not show Antigravity model scopes for OpenAI plans", () => {
+ const text = mountPlanCard("openai").text();
+
+ expect(text).not.toContain("Claude");
+ expect(text).not.toContain("Gemini");
+ expect(text).not.toContain("Imagen");
+ });
+
+ it("shows model scopes for Antigravity plans", () => {
+ const text = mountPlanCard("antigravity").text();
+
+ expect(text).toContain("Claude");
+ expect(text).toContain("Gemini");
+ expect(text).toContain("Imagen");
+ });
+});
diff --git a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
index e9530ff2..7eda7a0d 100644
--- a/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
+++ b/frontend/src/components/payment/__tests__/paymentFlow.spec.ts
@@ -220,6 +220,36 @@ describe('decidePaymentLaunch', () => {
expect(decision.jsapi?.appId).toBe('wx123')
expect(decision.paymentState.orderType).toBe('subscription')
})
+
+ it('forces qr_waiting for mobile alipay when forceQRCode is enabled', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/mobile/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'alipay',
+ orderType: 'balance',
+ isMobile: true,
+ forceQRCode: true,
+ })
+
+ expect(decision.kind).toBe('qr_waiting')
+ expect(decision.paymentState.qrCode).toBe('https://pay.example.com/qr/session')
+ })
+
+ it('does not affect non-alipay methods when forceQRCode is enabled', () => {
+ const decision = decidePaymentLaunch(createOrderResult({
+ pay_url: 'https://pay.example.com/mobile/session',
+ qr_code: 'https://pay.example.com/qr/session',
+ }), {
+ visibleMethod: 'wxpay',
+ orderType: 'balance',
+ isMobile: true,
+ forceQRCode: true,
+ })
+
+ // wxpay mobile with pay_url still redirects
+ expect(decision.kind).toBe('redirect_waiting')
+ })
})
describe('buildCreateOrderPayload', () => {
@@ -260,6 +290,34 @@ describe('buildCreateOrderPayload', () => {
payment_source: 'wechat_in_app_resume',
})
})
+
+ it('passes is_mobile: false when forceQRCode is enabled for alipay', () => {
+ expect(buildCreateOrderPayload({
+ amount: 50,
+ paymentType: 'alipay',
+ orderType: 'balance',
+ origin: 'https://app.example.com',
+ isMobile: true,
+ isWechatBrowser: false,
+ forceQRCode: true,
+ })).toMatchObject({
+ is_mobile: false,
+ })
+ })
+
+ it('still passes is_mobile: true when forceQRCode is enabled for non-alipay methods', () => {
+ expect(buildCreateOrderPayload({
+ amount: 50,
+ paymentType: 'wxpay',
+ orderType: 'balance',
+ origin: 'https://app.example.com',
+ isMobile: true,
+ isWechatBrowser: false,
+ forceQRCode: true,
+ })).toMatchObject({
+ is_mobile: true,
+ })
+ })
})
describe('readPaymentRecoverySnapshot', () => {
diff --git a/frontend/src/components/payment/paymentFlow.ts b/frontend/src/components/payment/paymentFlow.ts
index e66ef8e3..ab5acf26 100644
--- a/frontend/src/components/payment/paymentFlow.ts
+++ b/frontend/src/components/payment/paymentFlow.ts
@@ -55,6 +55,8 @@ export interface PaymentLaunchContext {
orderType: OrderType
isMobile: boolean
isWechatBrowser?: boolean
+ /** When true, Alipay payments always use QR code regardless of device type */
+ forceQRCode?: boolean
now?: number
stripePopupUrl?: string
stripeRouteUrl?: string
@@ -78,6 +80,8 @@ export interface BuildCreateOrderPayloadInput {
origin?: string
isMobile: boolean
isWechatBrowser: boolean
+ /** When true, Alipay payments always use QR code (passes is_mobile: false to backend) */
+ forceQRCode?: boolean
}
type CreateOrderFlowResult = CreateOrderResult & {
@@ -111,11 +115,16 @@ export function getVisibleMethods(methods: Record): Record<
export function buildCreateOrderPayload(input: BuildCreateOrderPayloadInput): CreateOrderRequest {
const visibleMethod = normalizeVisibleMethod(input.paymentType) || input.paymentType.trim()
const normalizedOrigin = (input.origin || '').trim().replace(/\/+$/, '')
+ // When forceQRCode is enabled for alipay, always tell the backend this is not a mobile
+ // request so it generates a QR code instead of a mobile-redirect URL.
+ const effectiveMobile = (input.forceQRCode && visibleMethod === 'alipay')
+ ? false
+ : input.isMobile
const payload: CreateOrderRequest = {
amount: input.amount,
payment_type: visibleMethod,
order_type: input.orderType,
- is_mobile: input.isMobile,
+ is_mobile: effectiveMobile,
payment_source: visibleMethod === 'wxpay' && input.isWechatBrowser
? 'wechat_in_app_resume'
: 'hosted_redirect',
@@ -190,9 +199,14 @@ export function decidePaymentLaunch(
}
const normalizedPaymentMode = baseState.paymentMode.trim().toLowerCase()
+ // When forceQRCode is on for alipay, treat the device as desktop so the mobile-redirect
+ // branch is bypassed and we fall through to qr_waiting.
+ const effectiveMobile = (context.forceQRCode && visibleMethod === 'alipay')
+ ? false
+ : context.isMobile
const prefersRedirect = normalizedPaymentMode === 'redirect'
|| normalizedPaymentMode === 'popup'
- || (context.isMobile && !!baseState.payUrl)
+ || (effectiveMobile && !!baseState.payUrl)
const prefersQr = normalizedPaymentMode === 'qrcode'
|| normalizedPaymentMode === 'native'
|| (!prefersRedirect && !!baseState.qrCode)
diff --git a/frontend/src/components/payment/providerConfig.ts b/frontend/src/components/payment/providerConfig.ts
index 9e69bf58..2b612b43 100644
--- a/frontend/src/components/payment/providerConfig.ts
+++ b/frontend/src/components/payment/providerConfig.ts
@@ -47,6 +47,11 @@ export const METHOD_ORDER = ['alipay', 'alipay_direct', 'wxpay', 'wxpay_direct',
/** Payment mode constants */
export const PAYMENT_MODE_QRCODE = 'qrcode'
export const PAYMENT_MODE_POPUP = 'popup'
+/** Alipay-only: skip FACE_TO_FACE_PAYMENT precreate and open the Alipay
+ * checkout page in a new tab instead. Backend `alipay.go` matches on this
+ * literal (case-insensitive); other values fall back to the default
+ * precreate→pagepay flow. */
+export const PAYMENT_MODE_REDIRECT = 'redirect'
export const PAYMENT_CURRENCY_OPTIONS: TypeOption[] = [
{ value: 'CNY', label: 'CNY' },
diff --git a/frontend/src/components/user/PlatformCostCell.vue b/frontend/src/components/user/PlatformCostCell.vue
new file mode 100644
index 00000000..dd5de111
--- /dev/null
+++ b/frontend/src/components/user/PlatformCostCell.vue
@@ -0,0 +1,24 @@
+
+
+
+ {{ t('admin.users.today') }}:
+ ${{ usage.today_actual_cost.toFixed(4) }}
+
+
+ {{ t('admin.users.total') }}:
+ ${{ usage.total_actual_cost.toFixed(4) }}
+
+
+ —
+
+
+
diff --git a/frontend/src/components/user/PlatformUsageBreakdown.vue b/frontend/src/components/user/PlatformUsageBreakdown.vue
new file mode 100644
index 00000000..e995bc01
--- /dev/null
+++ b/frontend/src/components/user/PlatformUsageBreakdown.vue
@@ -0,0 +1,103 @@
+
+
+
+ {{ t('admin.users.today') }}:
+ ${{ today.toFixed(4) }}
+
+
+
+ {{ t('admin.users.total') }}:
+ ${{ total.toFixed(4) }}
+
+
+
+
+ {{ t('admin.users.platformBreakdown') }}
+ {{ t('admin.users.today') }} / {{ t('admin.users.total') }}
+
+
+
+ {{ item.isOther ? t('admin.users.platformOther') : platformLabel(item.platform) }}
+
+
+ ${{ item.today_actual_cost.toFixed(4) }}
+ /
+ ${{ item.total_actual_cost.toFixed(4) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/components/user/dashboard/UserDashboardStats.vue b/frontend/src/components/user/dashboard/UserDashboardStats.vue
index dfba3a51..97d2da3d 100644
--- a/frontend/src/components/user/dashboard/UserDashboardStats.vue
+++ b/frontend/src/components/user/dashboard/UserDashboardStats.vue
@@ -131,20 +131,118 @@
+
+
+
+
+
{{ t('dashboard.platformBreakdown') }}
+
+ {{ t('dashboard.platformCount', { count: sortedPlatforms.length }) }}
+
+
+
+
+
+
+ {{ item.isOther ? t('dashboard.platformOther') : platformLabel(item.platform) }}
+
+
+ ${{ formatCost(item.total_actual_cost) }}
+
+
+
+
+ {{ t('dashboard.todayCost') }}
+ ${{ formatCost(item.today_actual_cost) }}
+
+
+ {{ t('dashboard.requests') }}
+
+ {{ item.total_requests > 0 ? formatNumber(item.total_requests) : '-' }}
+
+
+
+ {{ t('dashboard.tokens') }}
+
+ {{ item.total_tokens > 0 ? formatTokens(item.total_tokens) : '-' }}
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/admin/SubscriptionsView.vue b/frontend/src/views/admin/SubscriptionsView.vue
index cf39a4a7..1eae3a0b 100644
--- a/frontend/src/views/admin/SubscriptionsView.vue
+++ b/frontend/src/views/admin/SubscriptionsView.vue
@@ -246,7 +246,7 @@
d="M12 8v4l3 3m6-3a9 9 0 11-18 0 9 9 0 0118 0z"
/>
- {{ formatResetTime(row.daily_window_start, 'daily') }}
+ {{ formatDailyUsageWindow(row) }}
@@ -758,6 +758,7 @@ import Select from '@/components/common/Select.vue'
import GroupBadge from '@/components/common/GroupBadge.vue'
import GroupOptionItem from '@/components/common/GroupOptionItem.vue'
import Icon from '@/components/icons/Icon.vue'
+import { getRemainingDurationParts, isOneTimeDailyQuota, type RemainingDurationParts } from '@/utils/subscriptionQuota'
const { t } = useI18n()
const appStore = useAppStore()
@@ -1314,8 +1315,41 @@ const getProgressClass = (used: number | null | undefined, limit: number | null)
return 'bg-green-500'
}
+const formatResetDuration = (parts: RemainingDurationParts): string => {
+ if (parts.days > 0) {
+ return t('admin.subscriptions.resetInDaysHours', { days: parts.days, hours: parts.hours })
+ }
+
+ if (parts.hours > 0) {
+ return t('admin.subscriptions.resetInHoursMinutes', { hours: parts.hours, minutes: parts.minutes })
+ }
+
+ return t('admin.subscriptions.resetInMinutes', { minutes: parts.minutes })
+}
+
+const formatQuotaEndDuration = (parts: RemainingDurationParts): string => {
+ if (parts.days > 0) {
+ return t('admin.subscriptions.quotaEndsInDaysHours', { days: parts.days, hours: parts.hours })
+ }
+
+ if (parts.hours > 0) {
+ return t('admin.subscriptions.quotaEndsInHoursMinutes', { hours: parts.hours, minutes: parts.minutes })
+ }
+
+ return t('admin.subscriptions.quotaEndsInMinutes', { minutes: parts.minutes })
+}
+
+const formatDailyUsageWindow = (subscription: UserSubscription): string => {
+ if (isOneTimeDailyQuota(subscription) && subscription.expires_at) {
+ const parts = getRemainingDurationParts(subscription.expires_at)
+ return parts ? formatQuotaEndDuration(parts) : t('admin.subscriptions.windowNotActive')
+ }
+
+ return formatResetTime(subscription.daily_window_start, 'daily')
+}
+
// Format reset time based on window start and period type
-const formatResetTime = (windowStart: string, period: 'daily' | 'weekly' | 'monthly'): string => {
+const formatResetTime = (windowStart: string | null, period: 'daily' | 'weekly' | 'monthly'): string => {
if (!windowStart) return t('admin.subscriptions.windowNotActive')
const start = new Date(windowStart)
@@ -1335,21 +1369,9 @@ const formatResetTime = (windowStart: string, period: 'daily' | 'weekly' | 'mont
break
}
- const diffMs = resetTime.getTime() - now.getTime()
- if (diffMs <= 0) return t('admin.subscriptions.windowNotActive')
+ const parts = getRemainingDurationParts(resetTime, now)
- const diffSeconds = Math.floor(diffMs / 1000)
- const days = Math.floor(diffSeconds / 86400)
- const hours = Math.floor((diffSeconds % 86400) / 3600)
- const minutes = Math.floor((diffSeconds % 3600) / 60)
-
- if (days > 0) {
- return t('admin.subscriptions.resetInDaysHours', { days, hours })
- } else if (hours > 0) {
- return t('admin.subscriptions.resetInHoursMinutes', { hours, minutes })
- } else {
- return t('admin.subscriptions.resetInMinutes', { minutes })
- }
+ return parts ? formatResetDuration(parts) : t('admin.subscriptions.windowNotActive')
}
// Handle click outside to close dropdowns
diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue
index ea67f695..512bae67 100644
--- a/frontend/src/views/admin/UsersView.vue
+++ b/frontend/src/views/admin/UsersView.vue
@@ -199,15 +199,22 @@
{{ col.label }}
@@ -237,7 +244,7 @@
-
-
-
- {{ t('admin.users.today') }}:
-
- ${{ (usageStats[row.id]?.today_actual_cost ?? 0).toFixed(4) }}
-
-
-
-
{{ t('admin.users.total') }}:
-
- ${{ (usageStats[row.id]?.total_actual_cost ?? 0).toFixed(4) }}
-
+
+
+
+
{{ column.label }}
+
+
+ {{ usageSort.metric === 'today' ? t('admin.users.today') : t('admin.users.total') }}
+
+
+
+
+
+
+
+
+
+
+ {{ metric === 'today' ? t('admin.users.today') : t('admin.users.total') }}
+
+
+
+
+
+ {{ t('admin.users.sortCurrentPageOnly') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
(() => [
{ key: 'subscriptions', label: t('admin.users.columns.subscriptions'), sortable: false },
{ key: 'balance', label: t('admin.users.columns.balance'), sortable: true },
{ key: 'usage', label: t('admin.users.columns.usage'), sortable: false },
+ { key: 'usage_anthropic', label: t('admin.users.columns.usageAnthropic'), sortable: false },
+ { key: 'usage_openai', label: t('admin.users.columns.usageOpenAI'), sortable: false },
+ { key: 'usage_gemini', label: t('admin.users.columns.usageGemini'), sortable: false },
+ { key: 'usage_antigravity', label: t('admin.users.columns.usageAntigravity'), sortable: false },
{ key: 'concurrency', label: t('admin.users.columns.concurrency'), sortable: true },
{ key: 'status', label: t('admin.users.columns.status'), sortable: true },
{ key: 'last_active_at', label: t('admin.users.columns.lastActive'), sortable: true },
@@ -728,12 +827,25 @@ const toggleableColumns = computed(() =>
const hiddenColumns = reactive>(new Set())
// Default hidden columns (columns hidden by default on first load)
-const DEFAULT_HIDDEN_COLUMNS = ['notes', 'groups', 'subscriptions', 'usage', 'concurrency']
+const DEFAULT_HIDDEN_COLUMNS = [
+ 'notes', 'groups', 'subscriptions', 'usage', 'concurrency',
+ 'usage_anthropic', 'usage_openai', 'usage_gemini', 'usage_antigravity'
+]
const REMOVED_COLUMNS = new Set(['last_login_at'])
-const FORCED_VISIBLE_COLUMNS = new Set(['last_active_at'])
+// 强制可见列:加载时会被强制移出 hiddenColumns,并在列设置 UI 上 disabled。
+// 当前没有列需要强制可见 —— last_active_at 已改为可被用户隐藏。
+const FORCED_VISIBLE_COLUMNS = new Set()
-// localStorage key for column settings
+// localStorage keys for column settings
const HIDDEN_COLUMNS_KEY = 'user-hidden-columns'
+// 列设置 schema 版本号。每次给 DEFAULT_HIDDEN_COLUMNS 新增列时 bump 一次,
+// 并在 VERSION_NEW_HIDDEN_COLUMNS 中登记该版本新增的 key。
+// 这样老用户升级后这些新列会被自动隐藏一次,而不会影响他们对其它老列的偏好。
+const COLUMN_SETTINGS_VERSION_KEY = 'user-column-settings-version'
+const COLUMN_SETTINGS_VERSION = 2
+const VERSION_NEW_HIDDEN_COLUMNS: Record = {
+ 2: ['usage_anthropic', 'usage_openai', 'usage_gemini', 'usage_antigravity']
+}
// Load saved column settings
const loadSavedColumns = () => {
@@ -744,9 +856,27 @@ const loadSavedColumns = () => {
parsed
.filter(key => !REMOVED_COLUMNS.has(key) && !FORCED_VISIBLE_COLUMNS.has(key))
.forEach(key => hiddenColumns.add(key))
+
+ // 老用户升级:把每个未应用过的版本里新增的默认隐藏列自动追加到 hiddenColumns。
+ const storedVersion = Number(localStorage.getItem(COLUMN_SETTINGS_VERSION_KEY) ?? '1')
+ if (storedVersion < COLUMN_SETTINGS_VERSION) {
+ let mutated = false
+ for (let v = storedVersion + 1; v <= COLUMN_SETTINGS_VERSION; v++) {
+ for (const key of VERSION_NEW_HIDDEN_COLUMNS[v] ?? []) {
+ if (REMOVED_COLUMNS.has(key) || FORCED_VISIBLE_COLUMNS.has(key)) continue
+ if (!hiddenColumns.has(key)) {
+ hiddenColumns.add(key)
+ mutated = true
+ }
+ }
+ }
+ if (mutated) saveColumnsToStorage()
+ else localStorage.setItem(COLUMN_SETTINGS_VERSION_KEY, String(COLUMN_SETTINGS_VERSION))
+ }
} else {
// Use default hidden columns on first load
DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
+ localStorage.setItem(COLUMN_SETTINGS_VERSION_KEY, String(COLUMN_SETTINGS_VERSION))
}
} catch (e) {
console.error('Failed to load saved columns:', e)
@@ -758,13 +888,18 @@ const loadSavedColumns = () => {
const saveColumnsToStorage = () => {
try {
localStorage.setItem(HIDDEN_COLUMNS_KEY, JSON.stringify([...hiddenColumns]))
+ localStorage.setItem(COLUMN_SETTINGS_VERSION_KEY, String(COLUMN_SETTINGS_VERSION))
} catch (e) {
console.error('Failed to save columns:', e)
}
}
// Toggle column visibility
+const isForcedVisibleColumn = (key: string) => FORCED_VISIBLE_COLUMNS.has(key)
const toggleColumn = (key: string) => {
+ // 强制可见列(如 last_active_at)在加载时会被恢复成可见,
+ // 这里阻止用户在当前会话隐藏它,避免"取消勾选 → 刷新又恢复"的反直觉行为。
+ if (FORCED_VISIBLE_COLUMNS.has(key)) return
const wasHidden = hiddenColumns.has(key)
if (hiddenColumns.has(key)) {
hiddenColumns.delete(key)
@@ -772,7 +907,7 @@ const toggleColumn = (key: string) => {
hiddenColumns.add(key)
}
saveColumnsToStorage()
- if (wasHidden && (key === 'usage' || key.startsWith('attr_'))) {
+ if (wasHidden && (key === 'usage' || key.startsWith('usage_') || key.startsWith('attr_'))) {
refreshCurrentPageSecondaryData()
}
if (key === 'subscriptions') {
@@ -785,7 +920,22 @@ const toggleColumn = (key: string) => {
// Check if column is visible (not in hidden set)
const isColumnVisible = (key: string) => !hiddenColumns.has(key)
-const hasVisibleUsageColumn = computed(() => !hiddenColumns.has('usage'))
+// usage 主列或任意 usage_ 子列可见时都需要批量拉取用量数据
+// 列 key → 平台名('usage' 主列汇总所有平台时为 null)
+// 显式数组取代 Object.keys():保证迭代顺序(决定列头排序按钮渲染顺序)
+// 不会因 JS 引擎差异或 USAGE_COLUMN_PLATFORMS 属性顺序调整而静默变化。
+const USAGE_COLUMN_KEYS: readonly string[] = ['usage', 'usage_anthropic', 'usage_openai', 'usage_gemini', 'usage_antigravity']
+const USAGE_COLUMN_PLATFORMS: Record = {
+ usage: null,
+ usage_anthropic: 'anthropic',
+ usage_openai: 'openai',
+ usage_gemini: 'gemini',
+ usage_antigravity: 'antigravity'
+}
+const PLATFORM_USAGE_COLUMNS = USAGE_COLUMN_KEYS.filter((k) => k !== 'usage')
+const hasVisibleUsageColumn = computed(
+ () => !hiddenColumns.has('usage') || PLATFORM_USAGE_COLUMNS.some((k) => !hiddenColumns.has(k))
+)
const hasVisibleSubscriptionsColumn = computed(() => !hiddenColumns.has('subscriptions'))
const hasVisibleGroupsColumn = computed(() => !hiddenColumns.has('groups'))
const hasVisibleAttributeColumns = computed(() =>
@@ -945,6 +1095,97 @@ const getAttributeDefinition = (attrId: number): UserAttributeDefinition | undef
return attributeDefinitions.value.find(d => d.id === attrId)
}
const usageStats = ref>({})
+
+const getPlatformUsage = (userId: number, platform: string) =>
+ usageStats.value[userId]?.by_platform?.find((p) => p.platform === platform)
+
+// 用量列前端排序:DataTable 工作在 server-side-sort 模式,所有 sortable
+// 字段都会触发后端查询,而用量列数据是异步批量拉取后再合并到当前页,
+// 因此采用独立的前端排序状态对当前页 users 做本地排序。
+// 排序状态独立于后端 sortState 持久化;缺失数据按 0 处理(desc 沉底、asc 置顶)。
+type UsageMetric = 'today' | 'total'
+type UsageSortState = { key: string; metric: UsageMetric; order: 'asc' | 'desc' } | null
+const USAGE_SORT_STORAGE_KEY = 'admin-users-usage-sort'
+
+const loadInitialUsageSort = (): UsageSortState => {
+ try {
+ const raw = localStorage.getItem(USAGE_SORT_STORAGE_KEY)
+ if (!raw) return null
+ const parsed = JSON.parse(raw) as Partial<{ key: string; metric: string; order: string }>
+ if (!parsed.key || !USAGE_COLUMN_KEYS.includes(parsed.key)) return null
+ const metric: UsageMetric = parsed.metric === 'total' ? 'total' : 'today'
+ const order: 'asc' | 'desc' = parsed.order === 'asc' ? 'asc' : 'desc'
+ return { key: parsed.key, metric, order }
+ } catch {
+ return null
+ }
+}
+const usageSort = ref(loadInitialUsageSort())
+const persistUsageSort = () => {
+ try {
+ if (usageSort.value) {
+ localStorage.setItem(USAGE_SORT_STORAGE_KEY, JSON.stringify(usageSort.value))
+ } else {
+ localStorage.removeItem(USAGE_SORT_STORAGE_KEY)
+ }
+ } catch (e) {
+ console.error('Failed to persist usage sort:', e)
+ }
+}
+
+const isUsageSortActive = (key: string, metric: UsageMetric) =>
+ !!usageSort.value && usageSort.value.key === key && usageSort.value.metric === metric
+const getUsageSortOrder = (key: string, metric: UsageMetric): 'asc' | 'desc' | null =>
+ isUsageSortActive(key, metric) ? usageSort.value!.order : null
+
+// 三态循环:desc → asc → off。选完即关闭菜单(用户大多希望"选中即应用",
+// 想再切换 order 时重新打开菜单点同一项即可)。
+const toggleUsageSort = (key: string, metric: UsageMetric) => {
+ const cur = usageSort.value
+ if (cur && cur.key === key && cur.metric === metric) {
+ usageSort.value = cur.order === 'desc' ? { key, metric, order: 'asc' } : null
+ } else {
+ usageSort.value = { key, metric, order: 'desc' }
+ }
+ persistUsageSort()
+ openUsageSortMenu.value = null
+}
+
+// 列头排序按钮点击后弹出的"今日/近30天"选择菜单,同时只允许一个列展开。
+// 点击图标本身不触发排序,仅开关菜单;首次排序由用户在菜单内选择 metric 触发(默认 desc,详见 toggleUsageSort)。
+const openUsageSortMenu = ref(null)
+const toggleUsageSortMenu = (key: string) => {
+ openUsageSortMenu.value = openUsageSortMenu.value === key ? null : key
+}
+
+const getUsageValue = (userId: number, key: string, metric: UsageMetric): number => {
+ const stats = usageStats.value[userId]
+ if (!stats) return 0
+ const platform = USAGE_COLUMN_PLATFORMS[key]
+ if (platform === null) {
+ return metric === 'today' ? stats.today_actual_cost ?? 0 : stats.total_actual_cost ?? 0
+ }
+ const p = stats.by_platform?.find((x) => x.platform === platform)
+ if (!p) return 0
+ return metric === 'today' ? p.today_actual_cost ?? 0 : p.total_actual_cost ?? 0
+}
+
+// 在 server-side 排序结果之上叠加用量列的本地排序;无 usageSort 时直接透传原数组。
+// 稳定排序:等值按原 index 保序,避免拉取新用量数据时表行抖动。
+const sortedUsers = computed(() => {
+ const s = usageSort.value
+ if (!s) return users.value
+ return [...users.value]
+ .map((row, index) => ({ row, index }))
+ .sort((a, b) => {
+ const av = getUsageValue(a.row.id, s.key, s.metric)
+ const bv = getUsageValue(b.row.id, s.key, s.metric)
+ if (av !== bv) return s.order === 'asc' ? av - bv : bv - av
+ return a.index - b.index
+ })
+ .map((x) => x.row)
+})
+
// User attribute definitions and values
const attributeDefinitions = ref([])
const userAttributeValues = ref>>({})
@@ -1095,6 +1336,10 @@ const handleClickOutside = (event: MouseEvent) => {
if (columnDropdownRef.value && !columnDropdownRef.value.contains(target)) {
showColumnDropdown.value = false
}
+ // Close usage sort dropdown when clicking outside any usage-sort-trigger
+ if (openUsageSortMenu.value !== null && !target.closest('.usage-sort-trigger')) {
+ openUsageSortMenu.value = null
+ }
// Close expanded group dropdown when clicking outside
if (expandedGroupUserId.value !== null) {
expandedGroupUserId.value = null
diff --git a/frontend/src/views/admin/__tests__/groupsSupportedModelScopes.spec.ts b/frontend/src/views/admin/__tests__/groupsSupportedModelScopes.spec.ts
new file mode 100644
index 00000000..182fafbd
--- /dev/null
+++ b/frontend/src/views/admin/__tests__/groupsSupportedModelScopes.spec.ts
@@ -0,0 +1,31 @@
+import { describe, expect, it } from "vitest";
+import { normalizeSupportedModelScopesForPlatform } from "../groupsSupportedModelScopes";
+
+describe("normalizeSupportedModelScopesForPlatform", () => {
+ it("preserves model scopes for Antigravity groups", () => {
+ expect(
+ normalizeSupportedModelScopesForPlatform("antigravity", [
+ "claude",
+ "gemini_text",
+ ]),
+ ).toEqual(["claude", "gemini_text"]);
+ });
+
+ it("returns an empty array for Antigravity groups without scopes", () => {
+ expect(normalizeSupportedModelScopesForPlatform("antigravity", undefined)).toEqual([]);
+ });
+
+ it("drops hidden model scopes for OpenAI groups", () => {
+ expect(
+ normalizeSupportedModelScopesForPlatform("openai", [
+ "claude",
+ "gemini_text",
+ "gemini_image",
+ ]),
+ ).toEqual([]);
+ });
+
+ it("drops hidden model scopes for other non-Antigravity groups", () => {
+ expect(normalizeSupportedModelScopesForPlatform("claude", ["claude"])).toEqual([]);
+ });
+});
diff --git a/frontend/src/views/admin/groupsSupportedModelScopes.ts b/frontend/src/views/admin/groupsSupportedModelScopes.ts
new file mode 100644
index 00000000..6f20e0d3
--- /dev/null
+++ b/frontend/src/views/admin/groupsSupportedModelScopes.ts
@@ -0,0 +1,7 @@
+export const normalizeSupportedModelScopesForPlatform = (
+ platform: string,
+ scopes: string[] | undefined,
+): string[] => {
+ if (platform !== "antigravity") return [];
+ return scopes ?? [];
+};
diff --git a/frontend/src/views/admin/ops/OpsDashboard.vue b/frontend/src/views/admin/ops/OpsDashboard.vue
index 50bc5249..9c512db6 100644
--- a/frontend/src/views/admin/ops/OpsDashboard.vue
+++ b/frontend/src/views/admin/ops/OpsDashboard.vue
@@ -310,8 +310,6 @@ const applyRouteQueryToState = () => {
}
}
-applyRouteQueryToState()
-
const buildQueryFromState = () => {
const next: Record = { ...route.query }
@@ -380,6 +378,8 @@ const requestDetailsPreset = ref({
const showSettingsDialog = ref(false)
const showAlertRulesCard = ref(false)
+applyRouteQueryToState()
+
// Auto refresh settings
const showAlertEvents = ref(true)
const showOpenAITokenStats = ref(false)
diff --git a/frontend/src/views/admin/ops/components/OpsErrorDistributionChart.vue b/frontend/src/views/admin/ops/components/OpsErrorDistributionChart.vue
index a52b5442..ad7ce074 100644
--- a/frontend/src/views/admin/ops/components/OpsErrorDistributionChart.vue
+++ b/frontend/src/views/admin/ops/components/OpsErrorDistributionChart.vue
@@ -30,7 +30,11 @@ const colors = computed(() => ({
text: isDarkMode.value ? '#9ca3af' : '#6b7280'
}))
-const hasData = computed(() => (props.data?.total ?? 0) > 0)
+const totalSlaErrors = computed(() =>
+ (props.data?.items ?? []).reduce((total, item) => total + Number(item.sla || 0), 0)
+)
+
+const hasData = computed(() => totalSlaErrors.value > 0)
const state = computed(() => {
if (hasData.value) return 'ready'
@@ -54,7 +58,7 @@ const categories = computed(() => {
for (const item of props.data.items || []) {
const code = Number(item.status_code || 0)
- const count = Number(item.total || 0)
+ const count = Number(item.sla || 0)
if (!Number.isFinite(code) || !Number.isFinite(count)) continue
if ([502, 503, 504].includes(code)) upstream += count
diff --git a/frontend/src/views/admin/ops/components/OpsErrorTrendChart.vue b/frontend/src/views/admin/ops/components/OpsErrorTrendChart.vue
index 088dc317..6e07926f 100644
--- a/frontend/src/views/admin/ops/components/OpsErrorTrendChart.vue
+++ b/frontend/src/views/admin/ops/components/OpsErrorTrendChart.vue
@@ -45,9 +45,7 @@ const colors = computed(() => ({
text: isDarkMode.value ? '#9ca3af' : '#6b7280'
}))
-const totalRequestErrors = computed(() =>
- sumNumbers(props.points.map((p) => (p.error_count_sla ?? 0) + (p.business_limited_count ?? 0)))
-)
+const totalRequestErrors = computed(() => sumNumbers(props.points.map((p) => p.error_count_sla ?? 0)))
const totalUpstreamErrors = computed(() =>
sumNumbers(
diff --git a/frontend/src/views/admin/ops/components/__tests__/OpsErrorScopeCharts.spec.ts b/frontend/src/views/admin/ops/components/__tests__/OpsErrorScopeCharts.spec.ts
new file mode 100644
index 00000000..b7a6590f
--- /dev/null
+++ b/frontend/src/views/admin/ops/components/__tests__/OpsErrorScopeCharts.spec.ts
@@ -0,0 +1,147 @@
+import { mount } from '@vue/test-utils'
+import { describe, expect, it, vi } from 'vitest'
+import { defineComponent } from 'vue'
+import OpsErrorDistributionChart from '../OpsErrorDistributionChart.vue'
+import OpsErrorTrendChart from '../OpsErrorTrendChart.vue'
+
+vi.mock('chart.js', () => ({
+ Chart: { register: vi.fn() },
+ ArcElement: {},
+ CategoryScale: {},
+ Filler: {},
+ Legend: {},
+ LineElement: {},
+ LinearScale: {},
+ PointElement: {},
+ Title: {},
+ Tooltip: {},
+}))
+
+vi.mock('vue-chartjs', async () => {
+ const { defineComponent } = await import('vue')
+
+ return {
+ Doughnut: defineComponent({
+ name: 'Doughnut',
+ props: {
+ data: { type: Object, required: true },
+ options: { type: Object, default: () => ({}) },
+ },
+ template: '
',
+ }),
+ Line: defineComponent({
+ name: 'LineChartStub',
+ props: {
+ data: { type: Object, required: true },
+ options: { type: Object, default: () => ({}) },
+ },
+ template: '
',
+ }),
+ }
+})
+
+vi.mock('../../utils/opsFormatters', () => ({
+ formatHistoryLabel: (date: string | undefined) => date ?? '',
+ sumNumbers: (values: Array) =>
+ values.reduce((total, value) => total + (typeof value === 'number' && Number.isFinite(value) ? value : 0), 0),
+}))
+
+vi.mock('vue-i18n', async (importOriginal) => {
+ const actual = await importOriginal()
+
+ return {
+ ...actual,
+ useI18n: () => ({
+ t: (key: string) => key,
+ }),
+ }
+})
+
+const HelpTooltipStub = defineComponent({
+ name: 'HelpTooltip',
+ props: {
+ content: { type: String, default: '' },
+ },
+ template: ' ',
+})
+
+const EmptyStateStub = defineComponent({
+ name: 'EmptyState',
+ props: {
+ title: { type: String, default: '' },
+ description: { type: String, default: '' },
+ },
+ template: '
',
+})
+
+const globalStubs = {
+ stubs: {
+ HelpTooltip: HelpTooltipStub,
+ EmptyState: EmptyStateStub,
+ },
+}
+
+describe('Ops SLA-scoped error charts', () => {
+ it('错误分布图按 SLA 错误数统计,不把业务限制错误算进请求错误分布', () => {
+ const wrapper = mount(OpsErrorDistributionChart, {
+ props: {
+ loading: false,
+ data: {
+ total: 10,
+ items: [
+ { status_code: 400, total: 7, sla: 2, business_limited: 5 },
+ { status_code: 503, total: 3, sla: 0, business_limited: 3 },
+ ],
+ },
+ },
+ global: globalStubs,
+ })
+
+ const doughnut = wrapper.findComponent({ name: 'Doughnut' })
+ expect(doughnut.exists()).toBe(true)
+ expect(doughnut.props('data')).toMatchObject({
+ labels: ['admin.ops.client'],
+ datasets: [{ data: [2] }],
+ })
+ })
+
+ it('错误分布图在只有业务限制错误时显示为空态', () => {
+ const wrapper = mount(OpsErrorDistributionChart, {
+ props: {
+ loading: false,
+ data: {
+ total: 4,
+ items: [{ status_code: 500, total: 4, sla: 0, business_limited: 4 }],
+ },
+ },
+ global: globalStubs,
+ })
+
+ expect(wrapper.findComponent({ name: 'Doughnut' }).exists()).toBe(false)
+ expect(wrapper.find('.empty-state-stub').exists()).toBe(true)
+ })
+
+ it('错误趋势图的请求错误详情按钮只按 SLA 错误启用', () => {
+ const wrapper = mount(OpsErrorTrendChart, {
+ props: {
+ loading: false,
+ timeRange: '1h',
+ points: [
+ {
+ bucket_start: '2026-05-18T00:00:00Z',
+ error_count_total: 5,
+ business_limited_count: 5,
+ error_count_sla: 0,
+ upstream_error_count_excl_429_529: 0,
+ upstream_429_count: 0,
+ upstream_529_count: 0,
+ },
+ ],
+ },
+ global: globalStubs,
+ })
+
+ const requestErrorsButton = wrapper.findAll('button')[0]
+ expect(requestErrorsButton.attributes('disabled')).toBeDefined()
+ })
+})
diff --git a/frontend/src/views/admin/ops/utils/__tests__/errorDetailResponse.spec.ts b/frontend/src/views/admin/ops/utils/__tests__/errorDetailResponse.spec.ts
index 7d294e0c..3b809d0f 100644
--- a/frontend/src/views/admin/ops/utils/__tests__/errorDetailResponse.spec.ts
+++ b/frontend/src/views/admin/ops/utils/__tests__/errorDetailResponse.spec.ts
@@ -14,8 +14,6 @@ function makeDetail(overrides: Partial): OpsErrorDetail {
status_code: 502,
platform: 'openai',
model: 'gpt-4o-mini',
- is_retryable: true,
- retry_count: 0,
resolved: false,
client_request_id: 'crid-1',
request_id: 'rid-1',
@@ -25,8 +23,6 @@ function makeDetail(overrides: Partial): OpsErrorDetail {
group_name: 'group',
error_body: '',
user_agent: '',
- request_body: '',
- request_body_truncated: false,
is_business_limited: false,
...overrides
}
diff --git a/frontend/src/views/auth/DingTalkCallbackView.vue b/frontend/src/views/auth/DingTalkCallbackView.vue
new file mode 100644
index 00000000..ffe07c4f
--- /dev/null
+++ b/frontend/src/views/auth/DingTalkCallbackView.vue
@@ -0,0 +1,852 @@
+
+
+
+
+
+ {{ t('auth.dingtalk.callbackTitle') }}
+
+
+ {{ isProcessing ? t('auth.dingtalk.callbackProcessing') : t('auth.dingtalk.callbackHint') }}
+
+
+
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.profileDetailsTitle', { providerName }) }}
+
+
+ {{ t('auth.oauthFlow.profileDetailsDescription', { providerName }) }}
+
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.useDisplayName') }}
+
+
+ {{ suggestedDisplayName }}
+
+
+
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.useAvatar') }}
+
+
+ {{ suggestedAvatarUrl }}
+
+
+
+
+
+
+
+
+ {{ t('auth.dingtalk.invitationRequired') }}
+
+
+
+
+
+ {{ isSubmitting ? t('auth.dingtalk.completing') : t('auth.dingtalk.completeRegistration') }}
+
+
+
+
+
+ {{ t('auth.oauthFlow.reviewProfileBeforeContinue', { providerName }) }}
+
+
+ {{ isSubmitting ? t('common.processing') : t('auth.continue') }}
+
+
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.chooseHowToContinue') }}
+
+
+ {{
+ pendingAccountEmail
+ ? t('auth.oauthFlow.suggestedEmail', { email: pendingAccountEmail })
+ : t('auth.oauthFlow.chooseAccountActionHint')
+ }}
+
+
+
+
+
+ {{ t('auth.oauthFlow.bindExistingAccount') }}
+
+
+ {{ t('auth.oauthFlow.createNewAccount') }}
+
+
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.createAccountHint') }}
+
+
+
+
+
+
+ {{ t('auth.oauthFlow.bindLoginHint', { providerName }) }}
+
+
+
+
+
+ {{ isSubmitting ? t('common.processing') : t('auth.oauthFlow.logInAndBind') }}
+
+
+ {{ t('auth.oauthFlow.useDifferentEmail') }}
+
+
+
+
+
+
+ {{
+ t('auth.oauthFlow.totpHint', {
+ providerName,
+ account: totpUserEmailMasked || t('auth.oauthFlow.yourAccount')
+ })
+ }}
+
+
+
+
+ {{ isSubmitting ? t('common.processing') : t('auth.oauthFlow.verifyAndContinue') }}
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/auth/DingTalkEmailCompletionView.vue b/frontend/src/views/auth/DingTalkEmailCompletionView.vue
new file mode 100644
index 00000000..11e631cd
--- /dev/null
+++ b/frontend/src/views/auth/DingTalkEmailCompletionView.vue
@@ -0,0 +1,132 @@
+
+
+
+
+
+ {{ t('auth.dingtalk.createAccountTitle') }}
+
+
+ {{ t('auth.oauthFlow.createAccountHint') }}
+
+
+
+
+
+
+
+
+
diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue
index 3e89b079..52f3cfef 100644
--- a/frontend/src/views/auth/LoginView.vue
+++ b/frontend/src/views/auth/LoginView.vue
@@ -152,6 +152,11 @@
:disabled="authActionDisabled"
:show-divider="false"
/>
+
(false)
const turnstileEnabled = ref(false)
const turnstileSiteKey = ref('')
const linuxdoOAuthEnabled = ref(false)
+const dingtalkOAuthEnabled = ref(false)
const wechatOAuthEnabled = ref(false)
const backendModeEnabled = ref(false)
const oidcOAuthEnabled = ref(false)
@@ -283,6 +290,7 @@ const showOAuthLogin = computed(
() =>
!backendModeEnabled.value &&
(linuxdoOAuthEnabled.value ||
+ dingtalkOAuthEnabled.value ||
wechatOAuthEnabled.value ||
oidcOAuthEnabled.value ||
githubOAuthEnabled.value ||
@@ -311,6 +319,7 @@ onMounted(async () => {
turnstileEnabled.value = settings.turnstile_enabled
turnstileSiteKey.value = settings.turnstile_site_key || ''
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled
+ dingtalkOAuthEnabled.value = settings.dingtalk_oauth_enabled ?? false
wechatOAuthEnabled.value = isWeChatWebOAuthEnabled(settings)
backendModeEnabled.value = settings.backend_mode_enabled
oidcOAuthEnabled.value = settings.oidc_oauth_enabled
diff --git a/frontend/src/views/auth/OidcCallbackView.vue b/frontend/src/views/auth/OidcCallbackView.vue
index 51b17dbf..e17f05e9 100644
--- a/frontend/src/views/auth/OidcCallbackView.vue
+++ b/frontend/src/views/auth/OidcCallbackView.vue
@@ -338,6 +338,7 @@ type PendingOidcCompletion = PendingOAuthExchangeResponse & {
pending_email?: string
resolved_email?: string
existing_account_email?: string
+ compat_email?: string
email?: string
suggested_email?: string
provider_fallback?: string
@@ -461,6 +462,7 @@ function extractPendingAccountEmail(completion: PendingOidcCompletion): string {
return (
completion.pending_email ||
completion.existing_account_email ||
+ completion.compat_email ||
completion.resolved_email ||
completion.email ||
completion.suggested_email ||
diff --git a/frontend/src/views/user/PaymentResultView.vue b/frontend/src/views/user/PaymentResultView.vue
index 8d4b2d3e..51122d25 100644
--- a/frontend/src/views/user/PaymentResultView.vue
+++ b/frontend/src/views/user/PaymentResultView.vue
@@ -267,10 +267,15 @@ async function resolveOrderFromResumeToken(resumeToken: string): Promise {
try {
- const result = await paymentAPI.verifyOrderPublic(outTradeNo)
+ const result = await paymentAPI.verifyOrder(outTradeNo)
return result.data
} catch (_err: unknown) {
- return null
+ try {
+ const result = await paymentAPI.verifyOrderPublic(outTradeNo)
+ return result.data
+ } catch (_innerErr: unknown) {
+ return null
+ }
}
}
diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue
index 3c7a85fc..b7037b57 100644
--- a/frontend/src/views/user/PaymentView.vue
+++ b/frontend/src/views/user/PaymentView.vue
@@ -698,6 +698,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
origin: typeof window !== 'undefined' ? window.location.origin : '',
isMobile: isMobileDevice(),
isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
+ forceQRCode: !!(checkout.value.alipay_force_qrcode && normalizeVisibleMethod(requestType) === 'alipay'),
})
if (options.openid) {
payload.openid = options.openid
@@ -745,6 +746,7 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n
orderType,
isMobile: isMobileDevice(),
isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent),
+ forceQRCode: !!(checkout.value.alipay_force_qrcode && visibleMethod === 'alipay'),
stripePopupUrl: stripeRouteUrl,
stripeRouteUrl,
airwallexRouteUrl,
diff --git a/frontend/src/views/user/ProfileView.vue b/frontend/src/views/user/ProfileView.vue
index 84055119..a6481cee 100644
--- a/frontend/src/views/user/ProfileView.vue
+++ b/frontend/src/views/user/ProfileView.vue
@@ -7,6 +7,7 @@
(undefined)
const wechatOAuthMPEnabled = ref(undefined)
@@ -89,6 +91,7 @@ onMounted(async () => {
balanceLowNotifyEnabled.value = settings.balance_low_notify_enabled ?? false
systemDefaultThreshold.value = settings.balance_low_notify_threshold ?? 0
linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled ?? false
+ dingtalkOAuthEnabled.value = settings.dingtalk_oauth_enabled ?? false
wechatOAuthEnabled.value = isWeChatWebOAuthEnabled(settings)
wechatOAuthOpenEnabled.value = typeof settings.wechat_oauth_open_enabled === 'boolean'
? settings.wechat_oauth_open_enabled
diff --git a/frontend/src/views/user/SubscriptionsView.vue b/frontend/src/views/user/SubscriptionsView.vue
index 4453c12c..51616ce8 100644
--- a/frontend/src/views/user/SubscriptionsView.vue
+++ b/frontend/src/views/user/SubscriptionsView.vue
@@ -127,11 +127,7 @@
v-if="subscription.daily_window_start"
class="text-xs text-gray-500 dark:text-dark-400"
>
- {{
- t('userSubscriptions.resetIn', {
- time: formatResetTime(subscription.daily_window_start, 24)
- })
- }}
+ {{ formatDailyUsageWindow(subscription) }}
@@ -256,6 +252,7 @@ import AppLayout from '@/components/layout/AppLayout.vue'
import Icon from '@/components/icons/Icon.vue'
import { formatDateOnly } from '@/utils/format'
import { platformBorderClass, platformBadgeClass, platformButtonClass, platformLabel } from '@/utils/platformColors'
+import { getRemainingDurationParts, isOneTimeDailyQuota, type RemainingDurationParts } from '@/utils/subscriptionQuota'
function platformAccentDotClass(p: string): string {
switch (p) {
@@ -335,30 +332,38 @@ function getExpirationClass(expiresAt: string): string {
return 'text-gray-700 dark:text-gray-300'
}
+function formatDurationParts(parts: RemainingDurationParts): string {
+ if (parts.days > 0) {
+ return `${parts.days}d ${parts.hours}h`
+ }
+
+ if (parts.hours > 0) {
+ return `${parts.hours}h ${parts.minutes}m`
+ }
+
+ return `${parts.minutes}m`
+}
+
+function formatDailyUsageWindow(subscription: UserSubscription): string {
+ if (isOneTimeDailyQuota(subscription) && subscription.expires_at) {
+ const parts = getRemainingDurationParts(subscription.expires_at)
+ if (!parts) return t('userSubscriptions.windowNotActive')
+ return t('userSubscriptions.quotaEndsIn', { time: formatDurationParts(parts) })
+ }
+
+ return t('userSubscriptions.resetIn', {
+ time: formatResetTime(subscription.daily_window_start, 24)
+ })
+}
+
function formatResetTime(windowStart: string | null, windowHours: number): string {
if (!windowStart) return t('userSubscriptions.windowNotActive')
const start = new Date(windowStart)
const end = new Date(start.getTime() + windowHours * 60 * 60 * 1000)
- const now = new Date()
- const diff = end.getTime() - now.getTime()
+ const parts = getRemainingDurationParts(end)
- if (diff <= 0) return t('userSubscriptions.windowNotActive')
-
- const hours = Math.floor(diff / (1000 * 60 * 60))
- const minutes = Math.floor((diff % (1000 * 60 * 60)) / (1000 * 60))
-
- if (hours > 24) {
- const days = Math.floor(hours / 24)
- const remainingHours = hours % 24
- return `${days}d ${remainingHours}h`
- }
-
- if (hours > 0) {
- return `${hours}h ${minutes}m`
- }
-
- return `${minutes}m`
+ return parts ? formatDurationParts(parts) : t('userSubscriptions.windowNotActive')
}
onMounted(() => {
diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue
index 0eb8d455..d6807aaf 100644
--- a/frontend/src/views/user/UsageView.vue
+++ b/frontend/src/views/user/UsageView.vue
@@ -191,14 +191,14 @@
- {{ getBillingModeLabel(row.billing_mode, t) }}
+ :class="getBillingModeBadgeClass(getDisplayBillingMode(row))">
+ {{ getBillingModeLabel(getDisplayBillingMode(row), t) }}
-
-
+
+
- {{ row.image_count }}{{ $t('usage.imageUnit') }}
- ({{ row.image_size || '2K' }})
+ {{ row.image_count }}{{ t('usage.imageUnit') }}
+ ({{ formatImageBillingSize(row, t) }})
@@ -447,22 +447,31 @@
{{ t('admin.usage.outputCost') }}
${{ tooltipData.output_cost.toFixed(6) }}
-
-
-
- {{ t('usage.inputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
- {{ t('usage.outputTokenPrice') }}
- {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
-
-
-
-
+
+
{{ t('usage.imageCount') }}
- {{ tooltipData.image_count }}{{ t('usage.imageUnit') }} ({{ tooltipData.image_size || '2K' }})
+ {{ tooltipData.image_count }}{{ t('usage.imageUnit') }}
+
+
+ {{ t('usage.imageBillingSize') }}
+ {{ formatImageBillingSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageSizeSource') }}
+ {{ formatImageSizeSource(tooltipData, t) }}
+
+
+ {{ t('usage.imageInputSize') }}
+ {{ formatImageInputSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageOutputSize') }}
+ {{ formatImageOutputSize(tooltipData, t) }}
+
+
+ {{ t('usage.imageSizeBreakdown') }}
+ {{ formatImageSizeBreakdown(tooltipData) }}
{{ t('usage.imageUnitPrice') }}
@@ -473,6 +482,17 @@
${{ tooltipData.total_cost?.toFixed(6) || '0.000000' }}
+
+
+
+ {{ t('usage.inputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.input_cost, tooltipData.input_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
+ {{ t('usage.outputTokenPrice') }}
+ {{ formatTokenPricePerMillion(tooltipData.output_cost, tooltipData.output_tokens) }} {{ t('usage.perMillionTokens') }}
+
+
{{ t('usage.unitPrice') }}
${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}
@@ -538,7 +558,19 @@ import { formatCacheTokens, formatMultiplier } from '@/utils/formatters'
import { formatTokenPricePerMillion } from '@/utils/usagePricing'
import { getUsageServiceTierLabel } from '@/utils/usageServiceTier'
import { resolveUsageRequestType } from '@/utils/usageRequestType'
-import { getBillingModeLabel, getBillingModeBadgeClass } from '@/utils/billingMode'
+import {
+ BILLING_MODE_IMAGE,
+ BILLING_MODE_TOKEN,
+ getBillingModeBadgeClass,
+ getBillingModeLabel,
+} from '@/utils/billingMode'
+import {
+ formatImageBillingSize,
+ formatImageInputSize,
+ formatImageOutputSize,
+ formatImageSizeBreakdown,
+ formatImageSizeSource,
+} from '@/utils/imageUsage'
const { t } = useI18n()
const appStore = useAppStore()
@@ -646,6 +678,17 @@ const imageUnitPrice = (row: UsageLog | null): number => {
return Number.isFinite(price) ? price : 0
}
+const isImageUsage = (row: Pick
| null | undefined): boolean => {
+ return (row?.image_count ?? 0) > 0
+}
+
+const getDisplayBillingMode = (row: Pick | null | undefined): string | null | undefined => {
+ if (isImageUsage(row)) {
+ return BILLING_MODE_IMAGE
+ }
+ return row?.billing_mode
+}
+
const formatUserAgent = (ua: string): string => {
return ua
}
@@ -877,7 +920,7 @@ const exportToCSV = async () => {
formatReasoningEffort(log.reasoning_effort),
log.inbound_endpoint || '',
getRequestTypeExportText(log),
- getBillingModeLabel(log.billing_mode, t),
+ getBillingModeLabel(getDisplayBillingMode(log), t),
log.input_tokens,
log.output_tokens,
log.cache_read_tokens,
diff --git a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
index 49015ef4..09f2f0b6 100644
--- a/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
+++ b/frontend/src/views/user/__tests__/PaymentResultView.spec.ts
@@ -7,6 +7,7 @@ const routeState = vi.hoisted(() => ({
const routerPush = vi.hoisted(() => vi.fn())
const pollOrderStatus = vi.hoisted(() => vi.fn())
+const verifyOrder = vi.hoisted(() => vi.fn())
const verifyOrderPublic = vi.hoisted(() => vi.fn())
const resolveOrderPublicByResumeToken = vi.hoisted(() => vi.fn())
@@ -37,6 +38,7 @@ vi.mock('@/stores/payment', () => ({
vi.mock('@/api/payment', () => ({
paymentAPI: {
+ verifyOrder,
verifyOrderPublic,
resolveOrderPublicByResumeToken,
},
@@ -86,6 +88,7 @@ describe('PaymentResultView', () => {
routeState.query = {}
routerPush.mockReset()
pollOrderStatus.mockReset()
+ verifyOrder.mockReset()
verifyOrderPublic.mockReset()
resolveOrderPublicByResumeToken.mockReset()
window.localStorage.clear()
@@ -329,6 +332,7 @@ describe('PaymentResultView', () => {
out_trade_no: 'legacy-123',
trade_status: 'TRADE_SUCCESS',
}
+ verifyOrder.mockRejectedValue(new Error('auth required'))
verifyOrderPublic.mockResolvedValue({
data: orderFactory('PAID'),
})
@@ -343,11 +347,36 @@ describe('PaymentResultView', () => {
await flushPromises()
+ expect(verifyOrder).toHaveBeenCalledWith('legacy-123')
expect(verifyOrderPublic).toHaveBeenCalledWith('legacy-123')
expect(pollOrderStatus).not.toHaveBeenCalled()
expect(wrapper.text()).toContain('payment.result.success')
})
+ it('prefers authenticated order verification before falling back to public lookup', async () => {
+ routeState.query = {
+ out_trade_no: 'auth-verify-123',
+ trade_status: 'TRADE_SUCCESS',
+ }
+ verifyOrder.mockResolvedValue({
+ data: orderFactory('COMPLETED'),
+ })
+
+ const wrapper = mount(PaymentResultView, {
+ global: {
+ stubs: {
+ OrderStatusBadge: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ expect(verifyOrder).toHaveBeenCalledWith('auth-verify-123')
+ expect(verifyOrderPublic).not.toHaveBeenCalled()
+ expect(wrapper.text()).toContain('payment.result.success')
+ })
+
it('does not use public out_trade_no verification for bare order numbers without legacy return markers', async () => {
routeState.query = {
out_trade_no: 'legacy-bare',
diff --git a/frontend/src/views/user/__tests__/UsageView.spec.ts b/frontend/src/views/user/__tests__/UsageView.spec.ts
index 5f87619c..011b96c8 100644
--- a/frontend/src/views/user/__tests__/UsageView.spec.ts
+++ b/frontend/src/views/user/__tests__/UsageView.spec.ts
@@ -41,6 +41,26 @@ const messages: Record = {
'usage.duration': 'Duration',
'usage.time': 'Time',
'usage.userAgent': 'User Agent',
+ 'usage.imageUnit': ' images',
+ 'usage.imageCount': 'Image count',
+ 'usage.imageBillingSize': 'Billing size',
+ 'usage.imageInputSize': 'Input size',
+ 'usage.imageOutputSize': 'Output size',
+ 'usage.imageSizeSource': 'Size source',
+ 'usage.imageSizeBreakdown': 'Size breakdown',
+ 'usage.imageSizeSourceOutput': 'Upstream output',
+ 'usage.imageSizeSourceInput': 'Request input',
+ 'usage.imageSizeSourceDefault': 'Default billing tier',
+ 'usage.imageSizeSourceLegacy': 'Legacy record',
+ 'usage.imageSizeSourceMissing': 'Not recorded',
+ 'usage.imageSizeNotRecorded': 'not recorded',
+ 'usage.imageSizeLegacyUnstandardized': 'legacy unstandardized',
+ 'usage.imageSizeUnknown': 'unknown',
+ 'usage.imageUnitPrice': 'Per-image price',
+ 'usage.imageTotalPrice': 'Image total price',
+ 'admin.usage.billingModeToken': 'Token',
+ 'admin.usage.billingModePerRequest': 'Per request',
+ 'admin.usage.billingModeImage': 'Image',
}
vi.mock('@/api', () => ({
@@ -69,7 +89,19 @@ vi.mock('vue-i18n', async () => {
const AppLayoutStub = { template: '
' }
const TablePageLayoutStub = {
- template: '
',
+ template: '
',
+}
+const DataTableStub = {
+ props: ['data'],
+ template: `
+
+ `,
}
describe('user UsageView tooltip', () => {
@@ -146,6 +178,7 @@ describe('user UsageView tooltip', () => {
EmptyState: true,
Select: true,
DateRangePicker: true,
+ DataTable: DataTableStub,
Icon: true,
Teleport: true,
},
@@ -244,6 +277,7 @@ describe('user UsageView tooltip', () => {
EmptyState: true,
Select: true,
DateRangePicker: true,
+ DataTable: DataTableStub,
Icon: true,
Teleport: true,
},
@@ -274,4 +308,233 @@ describe('user UsageView tooltip', () => {
window.URL.revokeObjectURL = originalRevokeObjectURL
clickSpy.mockRestore()
})
+
+ it('exports historical image rows with image billing mode derived from image_count', async () => {
+ const exportedLogs = [
+ {
+ request_id: 'req-user-export-legacy-image',
+ actual_cost: 0.2,
+ total_cost: 0.2,
+ rate_multiplier: 1,
+ service_tier: null,
+ input_cost: 0,
+ output_cost: 0,
+ cache_creation_cost: 0,
+ cache_read_cost: 0,
+ input_tokens: 0,
+ output_tokens: 0,
+ cache_creation_tokens: 0,
+ cache_read_tokens: 0,
+ cache_creation_5m_tokens: 0,
+ cache_creation_1h_tokens: 0,
+ image_count: 1,
+ image_size: null,
+ billing_mode: null,
+ first_token_ms: null,
+ duration_ms: 345,
+ created_at: '2026-03-08T00:00:00Z',
+ model: 'gpt-image-2',
+ reasoning_effort: null,
+ api_key: { name: 'demo-key' },
+ },
+ ]
+
+ query.mockResolvedValue({
+ items: exportedLogs,
+ total: 1,
+ pages: 1,
+ })
+ getStatsByDateRange.mockResolvedValue({
+ total_requests: 1,
+ total_tokens: 0,
+ total_cost: 0.2,
+ avg_duration_ms: 1,
+ })
+ list.mockResolvedValue({ items: [] })
+
+ let exportedBlob: Blob | null = null
+ const originalCreateObjectURL = window.URL.createObjectURL
+ const originalRevokeObjectURL = window.URL.revokeObjectURL
+ window.URL.createObjectURL = vi.fn((blob: Blob | MediaSource) => {
+ exportedBlob = blob as Blob
+ return 'blob:usage-export'
+ }) as typeof window.URL.createObjectURL
+ window.URL.revokeObjectURL = vi.fn(() => {}) as typeof window.URL.revokeObjectURL
+ const clickSpy = vi.spyOn(HTMLAnchorElement.prototype, 'click').mockImplementation(() => {})
+
+ const wrapper = mount(UsageView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ TablePageLayout: TablePageLayoutStub,
+ Pagination: true,
+ EmptyState: true,
+ Select: true,
+ DateRangePicker: true,
+ DataTable: DataTableStub,
+ Icon: true,
+ Teleport: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ const setupState = (wrapper.vm as any).$?.setupState
+ await setupState.exportToCSV()
+
+ expect(exportedBlob).not.toBeNull()
+ const csv = await new Promise((resolve, reject) => {
+ const reader = new FileReader()
+ reader.onload = () => resolve(String(reader.result))
+ reader.onerror = () => reject(reader.error)
+ reader.readAsText(exportedBlob as Blob)
+ })
+ expect(csv).toContain('Billing Mode')
+ expect(csv).toContain('Image')
+ expect(csv).not.toContain(',Token,0,0,0,0,')
+
+ window.URL.createObjectURL = originalCreateObjectURL
+ window.URL.revokeObjectURL = originalRevokeObjectURL
+ clickSpy.mockRestore()
+ })
+
+ it('does not display a 2K fallback for historical image rows with missing size', async () => {
+ query.mockResolvedValue({
+ items: [
+ {
+ request_id: 'req-user-legacy-missing-image',
+ actual_cost: 0.2,
+ total_cost: 0.2,
+ rate_multiplier: 1,
+ service_tier: null,
+ input_cost: 0,
+ output_cost: 0,
+ cache_creation_cost: 0,
+ cache_read_cost: 0,
+ input_tokens: 0,
+ output_tokens: 0,
+ cache_creation_tokens: 0,
+ cache_read_tokens: 0,
+ cache_creation_5m_tokens: 0,
+ cache_creation_1h_tokens: 0,
+ image_count: 1,
+ image_size: null,
+ image_input_size: null,
+ image_output_size: null,
+ image_size_source: null,
+ image_size_breakdown: null,
+ billing_mode: null,
+ first_token_ms: null,
+ duration_ms: 1,
+ created_at: '2026-03-08T00:00:00Z',
+ model: 'gpt-image-2',
+ },
+ ],
+ total: 1,
+ pages: 1,
+ })
+ getStatsByDateRange.mockResolvedValue({
+ total_requests: 1,
+ total_tokens: 0,
+ total_cost: 0.2,
+ avg_duration_ms: 1,
+ })
+ list.mockResolvedValue({ items: [] })
+
+ const wrapper = mount(UsageView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ TablePageLayout: TablePageLayoutStub,
+ Pagination: true,
+ EmptyState: true,
+ Select: true,
+ DateRangePicker: true,
+ DataTable: DataTableStub,
+ Icon: true,
+ Teleport: true,
+ },
+ },
+ })
+
+ await flushPromises()
+ await nextTick()
+
+ const text = wrapper.text()
+ expect(text).toContain('Image')
+ expect(text).toContain('not recorded')
+ expect(text).not.toContain('(2K)')
+ })
+
+ it('shows image billing metadata in the user cost tooltip', async () => {
+ query.mockResolvedValue({
+ items: [],
+ total: 0,
+ pages: 0,
+ })
+ getStatsByDateRange.mockResolvedValue({
+ total_requests: 0,
+ total_tokens: 0,
+ total_cost: 0,
+ avg_duration_ms: 0,
+ })
+ list.mockResolvedValue({ items: [] })
+
+ const wrapper = mount(UsageView, {
+ global: {
+ stubs: {
+ AppLayout: AppLayoutStub,
+ TablePageLayout: TablePageLayoutStub,
+ Pagination: true,
+ EmptyState: true,
+ Select: true,
+ DateRangePicker: true,
+ DataTable: DataTableStub,
+ Icon: true,
+ Teleport: true,
+ },
+ },
+ })
+
+ await flushPromises()
+
+ const setupState = (wrapper.vm as any).$?.setupState
+ setupState.tooltipData = {
+ request_id: 'req-user-output-image',
+ actual_cost: 0.8,
+ total_cost: 0.8,
+ rate_multiplier: 1,
+ service_tier: null,
+ input_cost: 0,
+ output_cost: 0,
+ cache_creation_cost: 0,
+ cache_read_cost: 0,
+ input_tokens: 0,
+ output_tokens: 0,
+ cache_creation_tokens: 0,
+ cache_read_tokens: 0,
+ billing_mode: null,
+ image_count: 2,
+ image_size: '4K',
+ image_input_size: '1024x1024',
+ image_output_size: '3840x2160',
+ image_size_source: 'output',
+ image_size_breakdown: { '4K': 2 },
+ }
+ setupState.tooltipVisible = true
+ await nextTick()
+
+ const text = wrapper.text()
+ expect(text).toContain('Image count')
+ expect(text).toContain('Billing size')
+ expect(text).toContain('4K')
+ expect(text).toContain('Size source')
+ expect(text).toContain('Upstream output')
+ expect(text).toContain('Input size')
+ expect(text).toContain('1024x1024')
+ expect(text).toContain('Output size')
+ expect(text).toContain('3840x2160')
+ expect(text).toContain('4K x 2')
+ })
})