chore: merge upstream v0.1.127 — keep omniroute customizations

Upstream highlights:
- v0.1.127 release (150 commits): channel-monitor 协议管理、OpenAI
  Responses 路由配置、模型定价 LiteLLM 默认、payment 强制扫码、
  钉钉 OAuth、用户用量按平台拆分、Ops 错误分类 SLA 调整、
  Anthropic passthrough keepalive、Gemini chat completions 路由 ...
- 91da8159 feat(risk-control): 内容审计新增关键词拦截
- 3d22dd34 feat: gemini-3.5-flash 模型支持

Conflicts resolved:
- Dockerfile: keep pnpm pin to 9.15.9 (upstream pinned generic v9 floating).
- wire_gen.go: combine upstream NewSettingHandler(+userAttributeService)
  with local NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster).
  Verified by re-running wire generate.
- scheduler_cache.go: keep both upstream openai_responses_{mode,supported}
  keys and local model_rate_limits key in filterSchedulerExtra().
- gateway_service.go: keep local context-compression block; drop now-dead
  setOpsUpstreamRequestBody call (upstream removed ops retry replay).
- docker-compose.yml: keep local windsurf-ls service profile and named
  volumes; keep local healthcheck start_period values.

Test mock signatures bumped to match current constructors:
- gateway_models_test.go: add nil for RPMTokenBucketService.
- account_handler_available_models_test.go: add nil for windsurfChatService.
This commit is contained in:
win 2026-05-20 12:39:08 +08:00
commit 158785bfc9
323 changed files with 20637 additions and 3653 deletions

View File

@ -22,8 +22,8 @@ RUN sed -i 's#https://dl-cdn.alpinelinux.org/alpine#https://mirrors.aliyun.com/a
WORKDIR /app/frontend
# Install pnpm. Keep this on v9 so Docker builds do not fail on pnpm v10's
# interactive build-script approval flow.
# Install pnpm. Pin to a specific v9 patch to dodge pnpm v10's interactive
# build-script approval flow and keep Docker builds reproducible.
RUN corepack enable && corepack prepare pnpm@9.15.9 --activate
# Install dependencies first (better caching)

View File

@ -62,13 +62,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAi" width="150"></a></td>
<td>Thanks to Poixe Ai for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive <a href="https://poixe.com/i/sub2api">sub2api</a> referral link and receive a bonus of $5 USD on your first top-up.</td>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click <a href="https://ctok.ai">here</a> to register!</td>
</tr>
<tr>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click <a href="https://ctok.ai">here</a> to register!</td>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via <a href="https://aigocode.com/invite/SUB2API">this link</a>, you'll receive an extra 10% bonus credit on your first top-up!</td>
</tr>
<tr>
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
<td>Thanks to APIKEY.FUN for sponsoring this project! <a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> is one of the core contributors to the sub2api open-source project, dedicated to providing open, stable, and cost-effective AI API access. The platform supports API relay services for Claude, OpenAI, Gemini, and other popular models, with pricing starting from as low as 7% of the original rate. Register via the exclusive link: <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> to enjoy a permanent 5% discount on all recharges.</td>
</tr>
<tr>
@ -86,11 +91,6 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via <a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
<tr>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via <a href="https://aigocode.com/invite/SUB2API">this link</a>, you'll receive an extra 10% bonus credit on your first top-up!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)</td>
@ -108,6 +108,12 @@ Enterprise-grade high concurrency is also supported, with a dedicated management
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
</tr>
<tr>
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
<td>Thanks to PPToken.org for sponsoring this project! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> specializes in GPT model API relay services, supporting Codex, Claude Code, OpenAI-compatible clients, and Gemini CLI integration. Top-ups are 1:1 (¥1 = $1 credit); GPT models start at 0.16x rate multiplier, with overall cost at roughly 2.2% of official pricing and first-token latency around 1 second — ideal for developers seeking low-cost, high-speed access to GPT model capabilities. Technical support: 24/7 real human responses (no bots), @tech in the group chat and get a reply within 10 minutes. Sponsor benefit: the first 200 users who register via the <a href="https://api.pptoken.org/register?promo=SUB2API">exclusive registration link</a> and enter promo code `SUB2API` can claim free Codex / Claude Code trial credits — no minimum spend, no card required.
</td>
</tr>
</table>
## Ecosystem

View File

@ -61,13 +61,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAI" width="150"></a></td>
<td>感谢 Poixe AI 赞助了本项目Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 <a href="https://poixe.com/i/sub2api">此链接</a> 专属链接注册,充值额外赠送 $5 美金</td>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>感谢 CTok.ai 赞助了本项目CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群为开发者提供稳定的服务保障和持续的技术支持让 AI 辅助编程真正成为开发者的生产力工具。点击<a href="https://ctok.ai">这里</a>注册!</td>
</tr>
<tr>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>感谢 CTok.ai 赞助了本项目CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群为开发者提供稳定的服务保障和持续的技术支持让 AI 辅助编程真正成为开发者的生产力工具。点击<a href="https://ctok.ai">这里</a>注册!</td>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>感谢 AIGoCode 赞助了本项目AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过<a href="https://aigocode.com/invite/SUB2API">此链接</a>注册,首次充值可额外获得 10% 赠送额度!</td>
</tr>
<tr>
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
<td>感谢 APIKEY.FUN 赞助了本项目!<a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> 是 sub2api 开源项目的核心贡献者之一,致力于提供开放、稳定、高性价比的 AI API 接入服务。平台支持 Claude、OpenAI、Gemini 等热门模型的 API 中转服务,价格低至官方原价的 7%。通过专属链接 <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> 注册,可享受所有充值永久 95 折优惠。</td>
</tr>
<tr>
@ -85,11 +90,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
<td>感谢 AICodeMirror 赞助了本项目AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%充值更享额外折扣AICodeMirror 为 sub2api 用户提供专属福利:通过<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">此链接</a>注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!</td>
</tr>
<tr>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>感谢 AIGoCode 赞助了本项目AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过<a href="https://aigocode.com/invite/SUB2API">此链接</a>注册,首次充值可额外获得 10% 赠送额度!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
<td>感谢 BmoPlus 赞助了本项目BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI成品号专卖/代充</a>注册下单的用户可享GPT 官网订阅一折 的震撼价格!</td>
@ -107,6 +107,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
</tr>
<tr>
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
<td>感谢 PPToken.org 赞助本项目! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:11 元=1 美元额度GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术10 分钟内有回复 。赞助商福利:前 200 名用户通过 <a href="https://api.pptoken.org/register?promo=SUB2API">[专属注册链接]</a> 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。
</td>
</tr>
</table>
## 生态项目

View File

@ -61,13 +61,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAi" width="150"></a></td>
<td>Poixe AI のご支援に感謝しますPoixe AI は信頼性の高い LLM API サービスを提供しています。プラットフォームの API エンドポイントを活用して、AI 搭載プロダクトをシームレスに構築できます。また、ベンダーとして AI API リソースをプラットフォームに提供し、収益を得ることも可能です。専用の <a href="https://poixe.com/i/sub2api">sub2api</a> 紹介リンクから登録すると、初回チャージ時に $5 USD のボーナスがもらえます。</td>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>CTok.ai のご支援に感謝しますCTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。<a href="https://ctok.ai">こちら</a>から登録!</td>
</tr>
<tr>
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>CTok.ai のご支援に感謝しますCTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。<a href="https://ctok.ai">こちら</a>から登録!</td>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>AIGoCode のご支援に感謝しますAIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:<a href="https://aigocode.com/invite/SUB2API">こちらのリンク</a>から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!</td>
</tr>
<tr>
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
<td>APIKEY.FUN のご支援に感謝します!<a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> は sub2api オープンソースプロジェクトのコアコントリビューターの一つであり、オープンで安定した、コストパフォーマンスに優れた AI API アクセスサービスの提供に取り組んでいます。プラットフォームは Claude、OpenAI、Gemini など人気モデルの API 中継サービスをサポートし、価格は公式料金のわずか 7% から。専用リンク <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> から登録すると、すべてのチャージで永久 5% 割引をご利用いただけます。</td>
</tr>
<tr>
@ -85,11 +90,6 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
<td>AICodeMirror のご支援に感謝しますAICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引AICodeMirror は sub2api ユーザー向けに特別特典を提供中:<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">こちらのリンク</a>から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!</td>
</tr>
<tr>
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
<td>AIGoCode のご支援に感謝しますAIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:<a href="https://aigocode.com/invite/SUB2API">こちらのリンク</a>から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたしますBmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割90% OFF という驚異的な価格でご利用いただけます!</td>
@ -107,6 +107,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
</tr>
<tr>
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
<td>PPToken.org のご支援に感謝します!<a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> は GPT シリーズモデルの API 中継サービスを専門としており、Codex、Claude Code、OpenAI 互換クライアント、Gemini CLI などのツール接続をサポートしています。チャージは 1:11元1ドル分のクレジット、GPT モデルは最低 0.16 倍のレート倍率で、総合コストは公式価格の約 2.2% 、最速ファーストトークンは約1秒 — 開発者が低コスト・高速レスポンスで GPT モデル機能にアクセスするのに最適です。テクニカルサポート24時間365日リアルな人間が対応ボットではありません、グループ内で @技術 すれば 10 分以内に返信。スポンサー特典:先着 200 名のユーザーが<a href="https://api.pptoken.org/register?promo=SUB2API">専用登録リンク</a>から登録し、プロモコード `SUB2API` を入力すると、Codex / Claude Code の無料トライアルクレジットを獲得できます — 最低利用額なし、カード登録不要。
</td>
</tr>
</table>
## エコシステム

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

View File

@ -1 +1 @@
0.1.126
0.1.127

View File

@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
totpCache := repository.NewTotpCache(redisClient)
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
@ -202,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
registry := payment.ProvideRegistry()
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
requestEventBus := service.NewRequestEventBus()
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
@ -217,9 +220,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
@ -231,7 +231,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
channelHandler := admin.NewChannelHandler(channelService, billingService)
channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService)
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)

View File

@ -27,6 +27,8 @@ type ChannelMonitor struct {
Name string `json:"name,omitempty"`
// Provider holds the value of the "provider" field.
Provider channelmonitor.Provider `json:"provider,omitempty"`
// OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
APIMode string `json:"api_mode,omitempty"`
// Provider base origin, e.g. https://api.openai.com
Endpoint string `json:"endpoint,omitempty"`
// AES-256-GCM encrypted API key
@ -112,7 +114,7 @@ func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullBool)
case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
values[i] = new(sql.NullInt64)
case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldAPIMode, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
values[i] = new(sql.NullString)
case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
values[i] = new(sql.NullTime)
@ -161,6 +163,12 @@ func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Provider = channelmonitor.Provider(value.String)
}
case channelmonitor.FieldAPIMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field api_mode", values[i])
} else if value.Valid {
_m.APIMode = value.String
}
case channelmonitor.FieldEndpoint:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field endpoint", values[i])
@ -310,6 +318,9 @@ func (_m *ChannelMonitor) String() string {
builder.WriteString("provider=")
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
builder.WriteString(", ")
builder.WriteString("api_mode=")
builder.WriteString(_m.APIMode)
builder.WriteString(", ")
builder.WriteString("endpoint=")
builder.WriteString(_m.Endpoint)
builder.WriteString(", ")

View File

@ -23,6 +23,8 @@ const (
FieldName = "name"
// FieldProvider holds the string denoting the provider field in the database.
FieldProvider = "provider"
// FieldAPIMode holds the string denoting the api_mode field in the database.
FieldAPIMode = "api_mode"
// FieldEndpoint holds the string denoting the endpoint field in the database.
FieldEndpoint = "endpoint"
// FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database.
@ -87,6 +89,7 @@ var Columns = []string{
FieldUpdatedAt,
FieldName,
FieldProvider,
FieldAPIMode,
FieldEndpoint,
FieldAPIKeyEncrypted,
FieldPrimaryModel,
@ -121,6 +124,10 @@ var (
UpdateDefaultUpdatedAt func() time.Time
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
// DefaultAPIMode holds the default value on creation for the "api_mode" field.
DefaultAPIMode string
// APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
APIModeValidator func(string) error
// EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
EndpointValidator func(string) error
// APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
@ -197,6 +204,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProvider, opts...).ToFunc()
}
// ByAPIMode orders the results by the api_mode field.
func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
}
// ByEndpoint orders the results by the endpoint field.
func ByEndpoint(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldEndpoint, opts...).ToFunc()

View File

@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
}
// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
func APIMode(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
}
// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ.
func Endpoint(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
@ -285,6 +290,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...))
}
// APIModeEQ applies the EQ predicate on the "api_mode" field.
func APIModeEQ(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
}
// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
func APIModeNEQ(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIMode, v))
}
// APIModeIn applies the In predicate on the "api_mode" field.
func APIModeIn(vs ...string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldIn(FieldAPIMode, vs...))
}
// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
func APIModeNotIn(vs ...string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIMode, vs...))
}
// APIModeGT applies the GT predicate on the "api_mode" field.
func APIModeGT(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldGT(FieldAPIMode, v))
}
// APIModeGTE applies the GTE predicate on the "api_mode" field.
func APIModeGTE(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIMode, v))
}
// APIModeLT applies the LT predicate on the "api_mode" field.
func APIModeLT(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldLT(FieldAPIMode, v))
}
// APIModeLTE applies the LTE predicate on the "api_mode" field.
func APIModeLTE(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIMode, v))
}
// APIModeContains applies the Contains predicate on the "api_mode" field.
func APIModeContains(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldContains(FieldAPIMode, v))
}
// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
func APIModeHasPrefix(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIMode, v))
}
// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
func APIModeHasSuffix(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIMode, v))
}
// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
func APIModeEqualFold(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIMode, v))
}
// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
func APIModeContainsFold(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIMode, v))
}
// EndpointEQ applies the EQ predicate on the "endpoint" field.
func EndpointEQ(v string) predicate.ChannelMonitor {
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))

View File

@ -65,6 +65,20 @@ func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelM
return _c
}
// SetAPIMode sets the "api_mode" field.
func (_c *ChannelMonitorCreate) SetAPIMode(v string) *ChannelMonitorCreate {
_c.mutation.SetAPIMode(v)
return _c
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_c *ChannelMonitorCreate) SetNillableAPIMode(v *string) *ChannelMonitorCreate {
if v != nil {
_c.SetAPIMode(*v)
}
return _c
}
// SetEndpoint sets the "endpoint" field.
func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate {
_c.mutation.SetEndpoint(v)
@ -275,6 +289,10 @@ func (_c *ChannelMonitorCreate) defaults() {
v := channelmonitor.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.APIMode(); !ok {
v := channelmonitor.DefaultAPIMode
_c.mutation.SetAPIMode(v)
}
if _, ok := _c.mutation.ExtraModels(); !ok {
v := channelmonitor.DefaultExtraModels
_c.mutation.SetExtraModels(v)
@ -321,6 +339,14 @@ func (_c *ChannelMonitorCreate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
if _, ok := _c.mutation.APIMode(); !ok {
return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitor.api_mode"`)}
}
if v, ok := _c.mutation.APIMode(); ok {
if err := channelmonitor.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
}
}
if _, ok := _c.mutation.Endpoint(); !ok {
return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)}
}
@ -421,6 +447,10 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
_node.Provider = value
}
if value, ok := _c.mutation.APIMode(); ok {
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
_node.APIMode = value
}
if value, ok := _c.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
_node.Endpoint = value
@ -606,6 +636,18 @@ func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert {
return u
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorUpsert) SetAPIMode(v string) *ChannelMonitorUpsert {
u.Set(channelmonitor.FieldAPIMode, v)
return u
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorUpsert) UpdateAPIMode() *ChannelMonitorUpsert {
u.SetExcluded(channelmonitor.FieldAPIMode)
return u
}
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert {
u.Set(channelmonitor.FieldEndpoint, v)
@ -885,6 +927,20 @@ func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne {
})
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorUpsertOne) SetAPIMode(v string) *ChannelMonitorUpsertOne {
return u.Update(func(s *ChannelMonitorUpsert) {
s.SetAPIMode(v)
})
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorUpsertOne) UpdateAPIMode() *ChannelMonitorUpsertOne {
return u.Update(func(s *ChannelMonitorUpsert) {
s.UpdateAPIMode()
})
}
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne {
return u.Update(func(s *ChannelMonitorUpsert) {
@ -1362,6 +1418,20 @@ func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk {
})
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorUpsertBulk) SetAPIMode(v string) *ChannelMonitorUpsertBulk {
return u.Update(func(s *ChannelMonitorUpsert) {
s.SetAPIMode(v)
})
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorUpsertBulk) UpdateAPIMode() *ChannelMonitorUpsertBulk {
return u.Update(func(s *ChannelMonitorUpsert) {
s.UpdateAPIMode()
})
}
// SetEndpoint sets the "endpoint" field.
func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk {
return u.Update(func(s *ChannelMonitorUpsert) {

View File

@ -66,6 +66,20 @@ func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider)
return _u
}
// SetAPIMode sets the "api_mode" field.
func (_u *ChannelMonitorUpdate) SetAPIMode(v string) *ChannelMonitorUpdate {
_u.mutation.SetAPIMode(v)
return _u
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_u *ChannelMonitorUpdate) SetNillableAPIMode(v *string) *ChannelMonitorUpdate {
if v != nil {
_u.SetAPIMode(*v)
}
return _u
}
// SetEndpoint sets the "endpoint" field.
func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate {
_u.mutation.SetEndpoint(v)
@ -418,6 +432,11 @@ func (_u *ChannelMonitorUpdate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
if v, ok := _u.mutation.APIMode(); ok {
if err := channelmonitor.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
}
}
if v, ok := _u.mutation.Endpoint(); ok {
if err := channelmonitor.EndpointValidator(v); err != nil {
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
@ -472,6 +491,9 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
}
if value, ok := _u.mutation.APIMode(); ok {
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
}
if value, ok := _u.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
}
@ -701,6 +723,20 @@ func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provide
return _u
}
// SetAPIMode sets the "api_mode" field.
func (_u *ChannelMonitorUpdateOne) SetAPIMode(v string) *ChannelMonitorUpdateOne {
_u.mutation.SetAPIMode(v)
return _u
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_u *ChannelMonitorUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorUpdateOne {
if v != nil {
_u.SetAPIMode(*v)
}
return _u
}
// SetEndpoint sets the "endpoint" field.
func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne {
_u.mutation.SetEndpoint(v)
@ -1066,6 +1102,11 @@ func (_u *ChannelMonitorUpdateOne) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
}
}
if v, ok := _u.mutation.APIMode(); ok {
if err := channelmonitor.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
}
}
if v, ok := _u.mutation.Endpoint(); ok {
if err := channelmonitor.EndpointValidator(v); err != nil {
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
@ -1137,6 +1178,9 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
}
if value, ok := _u.mutation.APIMode(); ok {
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
}
if value, ok := _u.mutation.Endpoint(); ok {
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
}

View File

@ -26,6 +26,8 @@ type ChannelMonitorRequestTemplate struct {
Name string `json:"name,omitempty"`
// Provider holds the value of the "provider" field.
Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
// OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
APIMode string `json:"api_mode,omitempty"`
// Description holds the value of the "description" field.
Description string `json:"description,omitempty"`
// ExtraHeaders holds the value of the "extra_headers" field.
@ -67,7 +69,7 @@ func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error
values[i] = new([]byte)
case channelmonitorrequesttemplate.FieldID:
values[i] = new(sql.NullInt64)
case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldAPIMode, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
values[i] = new(sql.NullString)
case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
values[i] = new(sql.NullTime)
@ -116,6 +118,12 @@ func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values [
} else if value.Valid {
_m.Provider = channelmonitorrequesttemplate.Provider(value.String)
}
case channelmonitorrequesttemplate.FieldAPIMode:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field api_mode", values[i])
} else if value.Valid {
_m.APIMode = value.String
}
case channelmonitorrequesttemplate.FieldDescription:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field description", values[i])
@ -197,6 +205,9 @@ func (_m *ChannelMonitorRequestTemplate) String() string {
builder.WriteString("provider=")
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
builder.WriteString(", ")
builder.WriteString("api_mode=")
builder.WriteString(_m.APIMode)
builder.WriteString(", ")
builder.WriteString("description=")
builder.WriteString(_m.Description)
builder.WriteString(", ")

View File

@ -23,6 +23,8 @@ const (
FieldName = "name"
// FieldProvider holds the string denoting the provider field in the database.
FieldProvider = "provider"
// FieldAPIMode holds the string denoting the api_mode field in the database.
FieldAPIMode = "api_mode"
// FieldDescription holds the string denoting the description field in the database.
FieldDescription = "description"
// FieldExtraHeaders holds the string denoting the extra_headers field in the database.
@ -51,6 +53,7 @@ var Columns = []string{
FieldUpdatedAt,
FieldName,
FieldProvider,
FieldAPIMode,
FieldDescription,
FieldExtraHeaders,
FieldBodyOverrideMode,
@ -76,6 +79,10 @@ var (
UpdateDefaultUpdatedAt func() time.Time
// NameValidator is a validator for the "name" field. It is called by the builders before save.
NameValidator func(string) error
// DefaultAPIMode holds the default value on creation for the "api_mode" field.
DefaultAPIMode string
// APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
APIModeValidator func(string) error
// DefaultDescription holds the default value on creation for the "description" field.
DefaultDescription string
// DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
@ -140,6 +147,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldProvider, opts...).ToFunc()
}
// ByAPIMode orders the results by the api_mode field.
func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
}
// ByDescription orders the results by the description field.
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldDescription, opts...).ToFunc()

View File

@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
}
// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
func APIMode(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
}
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
func Description(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
@ -245,6 +250,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
}
// APIModeEQ applies the EQ predicate on the "api_mode" field.
func APIModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
}
// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
func APIModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldAPIMode, v))
}
// APIModeIn applies the In predicate on the "api_mode" field.
func APIModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldAPIMode, vs...))
}
// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
func APIModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldAPIMode, vs...))
}
// APIModeGT applies the GT predicate on the "api_mode" field.
func APIModeGT(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldAPIMode, v))
}
// APIModeGTE applies the GTE predicate on the "api_mode" field.
func APIModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldAPIMode, v))
}
// APIModeLT applies the LT predicate on the "api_mode" field.
func APIModeLT(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldAPIMode, v))
}
// APIModeLTE applies the LTE predicate on the "api_mode" field.
func APIModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldAPIMode, v))
}
// APIModeContains applies the Contains predicate on the "api_mode" field.
func APIModeContains(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldAPIMode, v))
}
// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
func APIModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldAPIMode, v))
}
// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
func APIModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldAPIMode, v))
}
// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
func APIModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldAPIMode, v))
}
// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
func APIModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldAPIMode, v))
}
// DescriptionEQ applies the EQ predicate on the "description" field.
func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))

View File

@ -63,6 +63,20 @@ func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorreque
return _c
}
// SetAPIMode sets the "api_mode" field.
func (_c *ChannelMonitorRequestTemplateCreate) SetAPIMode(v string) *ChannelMonitorRequestTemplateCreate {
_c.mutation.SetAPIMode(v)
return _c
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_c *ChannelMonitorRequestTemplateCreate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateCreate {
if v != nil {
_c.SetAPIMode(*v)
}
return _c
}
// SetDescription sets the "description" field.
func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
_c.mutation.SetDescription(v)
@ -161,6 +175,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
v := channelmonitorrequesttemplate.DefaultUpdatedAt()
_c.mutation.SetUpdatedAt(v)
}
if _, ok := _c.mutation.APIMode(); !ok {
v := channelmonitorrequesttemplate.DefaultAPIMode
_c.mutation.SetAPIMode(v)
}
if _, ok := _c.mutation.Description(); !ok {
v := channelmonitorrequesttemplate.DefaultDescription
_c.mutation.SetDescription(v)
@ -199,6 +217,14 @@ func (_c *ChannelMonitorRequestTemplateCreate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
if _, ok := _c.mutation.APIMode(); !ok {
return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.api_mode"`)}
}
if v, ok := _c.mutation.APIMode(); ok {
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
}
}
if v, ok := _c.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@ -258,6 +284,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequ
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
_node.Provider = value
}
if value, ok := _c.mutation.APIMode(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
_node.APIMode = value
}
if value, ok := _c.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
_node.Description = value
@ -378,6 +408,18 @@ func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRe
return u
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorRequestTemplateUpsert) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsert {
u.Set(channelmonitorrequesttemplate.FieldAPIMode, v)
return u
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorRequestTemplateUpsert) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsert {
u.SetExcluded(channelmonitorrequesttemplate.FieldAPIMode)
return u
}
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
u.Set(channelmonitorrequesttemplate.FieldDescription, v)
@ -525,6 +567,20 @@ func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonito
})
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorRequestTemplateUpsertOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
s.SetAPIMode(v)
})
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertOne {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
s.UpdateAPIMode()
})
}
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
@ -848,6 +904,20 @@ func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonit
})
}
// SetAPIMode sets the "api_mode" field.
func (u *ChannelMonitorRequestTemplateUpsertBulk) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
s.SetAPIMode(v)
})
}
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertBulk {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
s.UpdateAPIMode()
})
}
// SetDescription sets the "description" field.
func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {

View File

@ -63,6 +63,20 @@ func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmon
return _u
}
// SetAPIMode sets the "api_mode" field.
func (_u *ChannelMonitorRequestTemplateUpdate) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdate {
_u.mutation.SetAPIMode(v)
return _u
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdate {
if v != nil {
_u.SetAPIMode(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
_u.mutation.SetDescription(v)
@ -204,6 +218,11 @@ func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
if v, ok := _u.mutation.APIMode(); ok {
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
}
}
if v, ok := _u.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@ -238,6 +257,9 @@ func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_no
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
}
if value, ok := _u.mutation.APIMode(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
}
@ -355,6 +377,20 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channel
return _u
}
// SetAPIMode sets the "api_mode" field.
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
_u.mutation.SetAPIMode(v)
return _u
}
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
if v != nil {
_u.SetAPIMode(*v)
}
return _u
}
// SetDescription sets the "description" field.
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
_u.mutation.SetDescription(v)
@ -509,6 +545,11 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
}
}
if v, ok := _u.mutation.APIMode(); ok {
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
}
}
if v, ok := _u.mutation.Description(); ok {
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
@ -560,6 +601,9 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (
if value, ok := _u.mutation.Provider(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
}
if value, ok := _u.mutation.APIMode(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
}
if value, ok := _u.mutation.Description(); ok {
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
}

View File

@ -428,6 +428,7 @@ var (
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
{Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
{Name: "endpoint", Type: field.TypeString, Size: 500},
{Name: "api_key_encrypted", Type: field.TypeString},
{Name: "primary_model", Type: field.TypeString, Size: 200},
@ -450,7 +451,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
Columns: []*schema.Column{ChannelMonitorsColumns[17]},
Columns: []*schema.Column{ChannelMonitorsColumns[18]},
RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
OnDelete: schema.SetNull,
},
@ -459,22 +460,27 @@ var (
{
Name: "channelmonitor_enabled_last_checked_at",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]},
Columns: []*schema.Column{ChannelMonitorsColumns[11], ChannelMonitorsColumns[13]},
},
{
Name: "channelmonitor_provider",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[4]},
},
{
Name: "channelmonitor_provider_api_mode",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[4], ChannelMonitorsColumns[5]},
},
{
Name: "channelmonitor_group_name",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[9]},
Columns: []*schema.Column{ChannelMonitorsColumns[10]},
},
{
Name: "channelmonitor_template_id",
Unique: false,
Columns: []*schema.Column{ChannelMonitorsColumns[17]},
Columns: []*schema.Column{ChannelMonitorsColumns[18]},
},
},
}
@ -566,6 +572,7 @@ var (
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "name", Type: field.TypeString, Size: 100},
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
{Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
{Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
{Name: "extra_headers", Type: field.TypeJSON},
{Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
@ -582,6 +589,11 @@ var (
Unique: true,
Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
},
{
Name: "channelmonitorrequesttemplate_provider_api_mode",
Unique: false,
Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[5]},
},
},
}
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
@ -1120,6 +1132,7 @@ var (
{Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "validity_days", Type: field.TypeInt, Default: 30},
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
{Name: "used_by", Type: field.TypeInt64, Nullable: true},
@ -1132,13 +1145,13 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "redeem_codes_groups_redeem_codes",
Columns: []*schema.Column{RedeemCodesColumns[9]},
Columns: []*schema.Column{RedeemCodesColumns[10]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "redeem_codes_users_redeem_codes",
Columns: []*schema.Column{RedeemCodesColumns[10]},
Columns: []*schema.Column{RedeemCodesColumns[11]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.SetNull,
},
@ -1152,12 +1165,17 @@ var (
{
Name: "redeemcode_used_by",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[10]},
Columns: []*schema.Column{RedeemCodesColumns[11]},
},
{
Name: "redeemcode_group_id",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[9]},
Columns: []*schema.Column{RedeemCodesColumns[10]},
},
{
Name: "redeemcode_expires_at",
Unique: false,
Columns: []*schema.Column{RedeemCodesColumns[8]},
},
},
}
@ -1318,6 +1336,10 @@ var (
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
{Name: "image_count", Type: field.TypeInt, Default: 0},
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
{Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32},
{Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32},
{Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16},
{Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
{Name: "api_key_id", Type: field.TypeInt64},
@ -1334,31 +1356,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[35]},
Columns: []*schema.Column{UsageLogsColumns[39]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[36]},
Columns: []*schema.Column{UsageLogsColumns[40]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
Columns: []*schema.Column{UsageLogsColumns[37]},
Columns: []*schema.Column{UsageLogsColumns[41]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@ -1367,32 +1389,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[36]},
Columns: []*schema.Column{UsageLogsColumns[40]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33]},
Columns: []*schema.Column{UsageLogsColumns[37]},
},
{
Name: "usagelog_account_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[34]},
Columns: []*schema.Column{UsageLogsColumns[38]},
},
{
Name: "usagelog_group_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[35]},
Columns: []*schema.Column{UsageLogsColumns[39]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[37]},
Columns: []*schema.Column{UsageLogsColumns[41]},
},
{
Name: "usagelog_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[36]},
},
{
Name: "usagelog_model",
@ -1412,17 +1434,17 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]},
},
{
Name: "usagelog_group_id_created_at",
Unique: false,
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]},
},
},
}

View File

@ -8752,6 +8752,7 @@ type ChannelMonitorMutation struct {
updated_at *time.Time
name *string
provider *channelmonitor.Provider
api_mode *string
endpoint *string
api_key_encrypted *string
primary_model *string
@ -9023,6 +9024,42 @@ func (m *ChannelMonitorMutation) ResetProvider() {
m.provider = nil
}
// SetAPIMode sets the "api_mode" field.
func (m *ChannelMonitorMutation) SetAPIMode(s string) {
m.api_mode = &s
}
// APIMode returns the value of the "api_mode" field in the mutation.
func (m *ChannelMonitorMutation) APIMode() (r string, exists bool) {
v := m.api_mode
if v == nil {
return
}
return *v, true
}
// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitor entity.
// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorMutation) OldAPIMode(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAPIMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
}
return oldValue.APIMode, nil
}
// ResetAPIMode resets all changes to the "api_mode" field.
func (m *ChannelMonitorMutation) ResetAPIMode() {
m.api_mode = nil
}
// SetEndpoint sets the "endpoint" field.
func (m *ChannelMonitorMutation) SetEndpoint(s string) {
m.endpoint = &s
@ -9780,7 +9817,7 @@ func (m *ChannelMonitorMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorMutation) Fields() []string {
fields := make([]string, 0, 17)
fields := make([]string, 0, 18)
if m.created_at != nil {
fields = append(fields, channelmonitor.FieldCreatedAt)
}
@ -9793,6 +9830,9 @@ func (m *ChannelMonitorMutation) Fields() []string {
if m.provider != nil {
fields = append(fields, channelmonitor.FieldProvider)
}
if m.api_mode != nil {
fields = append(fields, channelmonitor.FieldAPIMode)
}
if m.endpoint != nil {
fields = append(fields, channelmonitor.FieldEndpoint)
}
@ -9848,6 +9888,8 @@ func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
return m.Name()
case channelmonitor.FieldProvider:
return m.Provider()
case channelmonitor.FieldAPIMode:
return m.APIMode()
case channelmonitor.FieldEndpoint:
return m.Endpoint()
case channelmonitor.FieldAPIKeyEncrypted:
@ -9891,6 +9933,8 @@ func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent
return m.OldName(ctx)
case channelmonitor.FieldProvider:
return m.OldProvider(ctx)
case channelmonitor.FieldAPIMode:
return m.OldAPIMode(ctx)
case channelmonitor.FieldEndpoint:
return m.OldEndpoint(ctx)
case channelmonitor.FieldAPIKeyEncrypted:
@ -9954,6 +9998,13 @@ func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
}
m.SetProvider(v)
return nil
case channelmonitor.FieldAPIMode:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAPIMode(v)
return nil
case channelmonitor.FieldEndpoint:
v, ok := value.(string)
if !ok {
@ -10160,6 +10211,9 @@ func (m *ChannelMonitorMutation) ResetField(name string) error {
case channelmonitor.FieldProvider:
m.ResetProvider()
return nil
case channelmonitor.FieldAPIMode:
m.ResetAPIMode()
return nil
case channelmonitor.FieldEndpoint:
m.ResetEndpoint()
return nil
@ -12591,6 +12645,7 @@ type ChannelMonitorRequestTemplateMutation struct {
updated_at *time.Time
name *string
provider *channelmonitorrequesttemplate.Provider
api_mode *string
description *string
extra_headers *map[string]string
body_override_mode *string
@ -12846,6 +12901,42 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
m.provider = nil
}
// SetAPIMode sets the "api_mode" field.
func (m *ChannelMonitorRequestTemplateMutation) SetAPIMode(s string) {
m.api_mode = &s
}
// APIMode returns the value of the "api_mode" field in the mutation.
func (m *ChannelMonitorRequestTemplateMutation) APIMode() (r string, exists bool) {
v := m.api_mode
if v == nil {
return
}
return *v, true
}
// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitorRequestTemplate entity.
// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *ChannelMonitorRequestTemplateMutation) OldAPIMode(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAPIMode requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
}
return oldValue.APIMode, nil
}
// ResetAPIMode resets all changes to the "api_mode" field.
func (m *ChannelMonitorRequestTemplateMutation) ResetAPIMode() {
m.api_mode = nil
}
// SetDescription sets the "description" field.
func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
m.description = &s
@ -13104,7 +13195,7 @@ func (m *ChannelMonitorRequestTemplateMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
fields := make([]string, 0, 8)
fields := make([]string, 0, 9)
if m.created_at != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
}
@ -13117,6 +13208,9 @@ func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
if m.provider != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
}
if m.api_mode != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldAPIMode)
}
if m.description != nil {
fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
}
@ -13145,6 +13239,8 @@ func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, b
return m.Name()
case channelmonitorrequesttemplate.FieldProvider:
return m.Provider()
case channelmonitorrequesttemplate.FieldAPIMode:
return m.APIMode()
case channelmonitorrequesttemplate.FieldDescription:
return m.Description()
case channelmonitorrequesttemplate.FieldExtraHeaders:
@ -13170,6 +13266,8 @@ func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, na
return m.OldName(ctx)
case channelmonitorrequesttemplate.FieldProvider:
return m.OldProvider(ctx)
case channelmonitorrequesttemplate.FieldAPIMode:
return m.OldAPIMode(ctx)
case channelmonitorrequesttemplate.FieldDescription:
return m.OldDescription(ctx)
case channelmonitorrequesttemplate.FieldExtraHeaders:
@ -13215,6 +13313,13 @@ func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.
}
m.SetProvider(v)
return nil
case channelmonitorrequesttemplate.FieldAPIMode:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAPIMode(v)
return nil
case channelmonitorrequesttemplate.FieldDescription:
v, ok := value.(string)
if !ok {
@ -13319,6 +13424,9 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
case channelmonitorrequesttemplate.FieldProvider:
m.ResetProvider()
return nil
case channelmonitorrequesttemplate.FieldAPIMode:
m.ResetAPIMode()
return nil
case channelmonitorrequesttemplate.FieldDescription:
m.ResetDescription()
return nil
@ -28602,6 +28710,7 @@ type RedeemCodeMutation struct {
used_at *time.Time
notes *string
created_at *time.Time
expires_at *time.Time
validity_days *int
addvalidity_days *int
clearedFields map[string]struct{}
@ -29059,6 +29168,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() {
m.created_at = nil
}
// SetExpiresAt sets the "expires_at" field.
func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) {
m.expires_at = &t
}
// ExpiresAt returns the value of the "expires_at" field in the mutation.
func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) {
v := m.expires_at
if v == nil {
return
}
return *v, true
}
// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity.
// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldExpiresAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
}
return oldValue.ExpiresAt, nil
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (m *RedeemCodeMutation) ClearExpiresAt() {
m.expires_at = nil
m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{}
}
// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
func (m *RedeemCodeMutation) ExpiresAtCleared() bool {
_, ok := m.clearedFields[redeemcode.FieldExpiresAt]
return ok
}
// ResetExpiresAt resets all changes to the "expires_at" field.
func (m *RedeemCodeMutation) ResetExpiresAt() {
m.expires_at = nil
delete(m.clearedFields, redeemcode.FieldExpiresAt)
}
// SetGroupID sets the "group_id" field.
func (m *RedeemCodeMutation) SetGroupID(i int64) {
m.group = &i
@ -29265,7 +29423,7 @@ func (m *RedeemCodeMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *RedeemCodeMutation) Fields() []string {
fields := make([]string, 0, 10)
fields := make([]string, 0, 11)
if m.code != nil {
fields = append(fields, redeemcode.FieldCode)
}
@ -29290,6 +29448,9 @@ func (m *RedeemCodeMutation) Fields() []string {
if m.created_at != nil {
fields = append(fields, redeemcode.FieldCreatedAt)
}
if m.expires_at != nil {
fields = append(fields, redeemcode.FieldExpiresAt)
}
if m.group != nil {
fields = append(fields, redeemcode.FieldGroupID)
}
@ -29320,6 +29481,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) {
return m.Notes()
case redeemcode.FieldCreatedAt:
return m.CreatedAt()
case redeemcode.FieldExpiresAt:
return m.ExpiresAt()
case redeemcode.FieldGroupID:
return m.GroupID()
case redeemcode.FieldValidityDays:
@ -29349,6 +29512,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val
return m.OldNotes(ctx)
case redeemcode.FieldCreatedAt:
return m.OldCreatedAt(ctx)
case redeemcode.FieldExpiresAt:
return m.OldExpiresAt(ctx)
case redeemcode.FieldGroupID:
return m.OldGroupID(ctx)
case redeemcode.FieldValidityDays:
@ -29418,6 +29583,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error {
}
m.SetCreatedAt(v)
return nil
case redeemcode.FieldExpiresAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetExpiresAt(v)
return nil
case redeemcode.FieldGroupID:
v, ok := value.(int64)
if !ok {
@ -29498,6 +29670,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string {
if m.FieldCleared(redeemcode.FieldNotes) {
fields = append(fields, redeemcode.FieldNotes)
}
if m.FieldCleared(redeemcode.FieldExpiresAt) {
fields = append(fields, redeemcode.FieldExpiresAt)
}
if m.FieldCleared(redeemcode.FieldGroupID) {
fields = append(fields, redeemcode.FieldGroupID)
}
@ -29524,6 +29699,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error {
case redeemcode.FieldNotes:
m.ClearNotes()
return nil
case redeemcode.FieldExpiresAt:
m.ClearExpiresAt()
return nil
case redeemcode.FieldGroupID:
m.ClearGroupID()
return nil
@ -29559,6 +29737,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error {
case redeemcode.FieldCreatedAt:
m.ResetCreatedAt()
return nil
case redeemcode.FieldExpiresAt:
m.ResetExpiresAt()
return nil
case redeemcode.FieldGroupID:
m.ResetGroupID()
return nil
@ -34260,6 +34441,10 @@ type UsageLogMutation struct {
image_count *int
addimage_count *int
image_size *string
image_input_size *string
image_output_size *string
image_size_source *string
image_size_breakdown *map[string]int
cache_ttl_overridden *bool
created_at *time.Time
clearedFields map[string]struct{}
@ -36202,6 +36387,202 @@ func (m *UsageLogMutation) ResetImageSize() {
delete(m.clearedFields, usagelog.FieldImageSize)
}
// SetImageInputSize sets the "image_input_size" field.
func (m *UsageLogMutation) SetImageInputSize(s string) {
m.image_input_size = &s
}
// ImageInputSize returns the value of the "image_input_size" field in the mutation.
func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) {
v := m.image_input_size
if v == nil {
return
}
return *v, true
}
// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageInputSize requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err)
}
return oldValue.ImageInputSize, nil
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (m *UsageLogMutation) ClearImageInputSize() {
m.image_input_size = nil
m.clearedFields[usagelog.FieldImageInputSize] = struct{}{}
}
// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation.
func (m *UsageLogMutation) ImageInputSizeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageInputSize]
return ok
}
// ResetImageInputSize resets all changes to the "image_input_size" field.
func (m *UsageLogMutation) ResetImageInputSize() {
m.image_input_size = nil
delete(m.clearedFields, usagelog.FieldImageInputSize)
}
// SetImageOutputSize sets the "image_output_size" field.
func (m *UsageLogMutation) SetImageOutputSize(s string) {
m.image_output_size = &s
}
// ImageOutputSize returns the value of the "image_output_size" field in the mutation.
func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) {
v := m.image_output_size
if v == nil {
return
}
return *v, true
}
// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageOutputSize requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err)
}
return oldValue.ImageOutputSize, nil
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (m *UsageLogMutation) ClearImageOutputSize() {
m.image_output_size = nil
m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{}
}
// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation.
func (m *UsageLogMutation) ImageOutputSizeCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageOutputSize]
return ok
}
// ResetImageOutputSize resets all changes to the "image_output_size" field.
func (m *UsageLogMutation) ResetImageOutputSize() {
m.image_output_size = nil
delete(m.clearedFields, usagelog.FieldImageOutputSize)
}
// SetImageSizeSource sets the "image_size_source" field.
func (m *UsageLogMutation) SetImageSizeSource(s string) {
m.image_size_source = &s
}
// ImageSizeSource returns the value of the "image_size_source" field in the mutation.
func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) {
v := m.image_size_source
if v == nil {
return
}
return *v, true
}
// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageSizeSource requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err)
}
return oldValue.ImageSizeSource, nil
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (m *UsageLogMutation) ClearImageSizeSource() {
m.image_size_source = nil
m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{}
}
// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation.
func (m *UsageLogMutation) ImageSizeSourceCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageSizeSource]
return ok
}
// ResetImageSizeSource resets all changes to the "image_size_source" field.
func (m *UsageLogMutation) ResetImageSizeSource() {
m.image_size_source = nil
delete(m.clearedFields, usagelog.FieldImageSizeSource)
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) {
m.image_size_breakdown = &value
}
// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation.
func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) {
v := m.image_size_breakdown
if v == nil {
return
}
return *v, true
}
// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity.
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err)
}
return oldValue.ImageSizeBreakdown, nil
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (m *UsageLogMutation) ClearImageSizeBreakdown() {
m.image_size_breakdown = nil
m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{}
}
// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation.
func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool {
_, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown]
return ok
}
// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field.
func (m *UsageLogMutation) ResetImageSizeBreakdown() {
m.image_size_breakdown = nil
delete(m.clearedFields, usagelog.FieldImageSizeBreakdown)
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
m.cache_ttl_overridden = &b
@ -36443,7 +36824,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
fields := make([]string, 0, 37)
fields := make([]string, 0, 41)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@ -36549,6 +36930,18 @@ func (m *UsageLogMutation) Fields() []string {
if m.image_size != nil {
fields = append(fields, usagelog.FieldImageSize)
}
if m.image_input_size != nil {
fields = append(fields, usagelog.FieldImageInputSize)
}
if m.image_output_size != nil {
fields = append(fields, usagelog.FieldImageOutputSize)
}
if m.image_size_source != nil {
fields = append(fields, usagelog.FieldImageSizeSource)
}
if m.image_size_breakdown != nil {
fields = append(fields, usagelog.FieldImageSizeBreakdown)
}
if m.cache_ttl_overridden != nil {
fields = append(fields, usagelog.FieldCacheTTLOverridden)
}
@ -36633,6 +37026,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ImageCount()
case usagelog.FieldImageSize:
return m.ImageSize()
case usagelog.FieldImageInputSize:
return m.ImageInputSize()
case usagelog.FieldImageOutputSize:
return m.ImageOutputSize()
case usagelog.FieldImageSizeSource:
return m.ImageSizeSource()
case usagelog.FieldImageSizeBreakdown:
return m.ImageSizeBreakdown()
case usagelog.FieldCacheTTLOverridden:
return m.CacheTTLOverridden()
case usagelog.FieldCreatedAt:
@ -36716,6 +37117,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldImageCount(ctx)
case usagelog.FieldImageSize:
return m.OldImageSize(ctx)
case usagelog.FieldImageInputSize:
return m.OldImageInputSize(ctx)
case usagelog.FieldImageOutputSize:
return m.OldImageOutputSize(ctx)
case usagelog.FieldImageSizeSource:
return m.OldImageSizeSource(ctx)
case usagelog.FieldImageSizeBreakdown:
return m.OldImageSizeBreakdown(ctx)
case usagelog.FieldCacheTTLOverridden:
return m.OldCacheTTLOverridden(ctx)
case usagelog.FieldCreatedAt:
@ -36974,6 +37383,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetImageSize(v)
return nil
case usagelog.FieldImageInputSize:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageInputSize(v)
return nil
case usagelog.FieldImageOutputSize:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageOutputSize(v)
return nil
case usagelog.FieldImageSizeSource:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageSizeSource(v)
return nil
case usagelog.FieldImageSizeBreakdown:
v, ok := value.(map[string]int)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetImageSizeBreakdown(v)
return nil
case usagelog.FieldCacheTTLOverridden:
v, ok := value.(bool)
if !ok {
@ -37291,6 +37728,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldImageSize) {
fields = append(fields, usagelog.FieldImageSize)
}
if m.FieldCleared(usagelog.FieldImageInputSize) {
fields = append(fields, usagelog.FieldImageInputSize)
}
if m.FieldCleared(usagelog.FieldImageOutputSize) {
fields = append(fields, usagelog.FieldImageOutputSize)
}
if m.FieldCleared(usagelog.FieldImageSizeSource) {
fields = append(fields, usagelog.FieldImageSizeSource)
}
if m.FieldCleared(usagelog.FieldImageSizeBreakdown) {
fields = append(fields, usagelog.FieldImageSizeBreakdown)
}
return fields
}
@ -37347,6 +37796,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldImageSize:
m.ClearImageSize()
return nil
case usagelog.FieldImageInputSize:
m.ClearImageInputSize()
return nil
case usagelog.FieldImageOutputSize:
m.ClearImageOutputSize()
return nil
case usagelog.FieldImageSizeSource:
m.ClearImageSizeSource()
return nil
case usagelog.FieldImageSizeBreakdown:
m.ClearImageSizeBreakdown()
return nil
}
return fmt.Errorf("unknown UsageLog nullable field %s", name)
}
@ -37460,6 +37921,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldImageSize:
m.ResetImageSize()
return nil
case usagelog.FieldImageInputSize:
m.ResetImageInputSize()
return nil
case usagelog.FieldImageOutputSize:
m.ResetImageOutputSize()
return nil
case usagelog.FieldImageSizeSource:
m.ResetImageSizeSource()
return nil
case usagelog.FieldImageSizeBreakdown:
m.ResetImageSizeBreakdown()
return nil
case usagelog.FieldCacheTTLOverridden:
m.ResetCacheTTLOverridden()
return nil

View File

@ -35,6 +35,8 @@ type RedeemCode struct {
Notes *string `json:"notes,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// ExpiresAt holds the value of the "expires_at" field.
ExpiresAt *time.Time `json:"expires_at,omitempty"`
// GroupID holds the value of the "group_id" field.
GroupID *int64 `json:"group_id,omitempty"`
// ValidityDays holds the value of the "validity_days" field.
@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) {
values[i] = new(sql.NullInt64)
case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes:
values[i] = new(sql.NullString)
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt:
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.CreatedAt = value.Time
}
case redeemcode.FieldExpiresAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
} else if value.Valid {
_m.ExpiresAt = new(time.Time)
*_m.ExpiresAt = value.Time
}
case redeemcode.FieldGroupID:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field group_id", values[i])
@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string {
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
if v := _m.ExpiresAt; v != nil {
builder.WriteString("expires_at=")
builder.WriteString(v.Format(time.ANSIC))
}
builder.WriteString(", ")
if v := _m.GroupID; v != nil {
builder.WriteString("group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))

View File

@ -30,6 +30,8 @@ const (
FieldNotes = "notes"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldExpiresAt holds the string denoting the expires_at field in the database.
FieldExpiresAt = "expires_at"
// FieldGroupID holds the string denoting the group_id field in the database.
FieldGroupID = "group_id"
// FieldValidityDays holds the string denoting the validity_days field in the database.
@ -67,6 +69,7 @@ var Columns = []string{
FieldUsedAt,
FieldNotes,
FieldCreatedAt,
FieldExpiresAt,
FieldGroupID,
FieldValidityDays,
}
@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByExpiresAt orders the results by the expires_at field.
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
}
// ByGroupID orders the results by the group_id field.
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldGroupID, opts...).ToFunc()

View File

@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v))
}
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
func ExpiresAt(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
}
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
func GroupID(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v))
}
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
func ExpiresAtEQ(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
}
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
func ExpiresAtNEQ(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v))
}
// ExpiresAtIn applies the In predicate on the "expires_at" field.
func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...))
}
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...))
}
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
func ExpiresAtGT(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v))
}
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
func ExpiresAtGTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v))
}
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
func ExpiresAtLT(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v))
}
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
func ExpiresAtLTE(v time.Time) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v))
}
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
func ExpiresAtIsNil() predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt))
}
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
func ExpiresAtNotNil() predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt))
}
// GroupIDEQ applies the EQ predicate on the "group_id" field.
func GroupIDEQ(v int64) predicate.RedeemCode {
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))

View File

@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate
return _c
}
// SetExpiresAt sets the "expires_at" field.
func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate {
_c.mutation.SetExpiresAt(v)
return _c
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate {
if v != nil {
_c.SetExpiresAt(*v)
}
return _c
}
// SetGroupID sets the "group_id" field.
func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate {
_c.mutation.SetGroupID(v)
@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) {
_spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
_node.ExpiresAt = &value
}
if value, ok := _c.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
_node.ValidityDays = value
@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert {
return u
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert {
u.Set(redeemcode.FieldExpiresAt, v)
return u
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert {
u.SetExcluded(redeemcode.FieldExpiresAt)
return u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert {
u.SetNull(redeemcode.FieldExpiresAt)
return u
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert {
u.Set(redeemcode.FieldGroupID, v)
@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne {
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
s.ClearExpiresAt()
})
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne {
return u.Update(func(s *RedeemCodeUpsert) {
@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk {
})
}
// SetExpiresAt sets the "expires_at" field.
func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.SetExpiresAt(v)
})
}
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.UpdateExpiresAt()
})
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {
s.ClearExpiresAt()
})
}
// SetGroupID sets the "group_id" field.
func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk {
return u.Update(func(s *RedeemCodeUpsert) {

View File

@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate {
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate {
_u.mutation.ClearExpiresAt()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate {
_u.mutation.SetGroupID(v)
@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error)
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
}
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}
@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne {
return _u
}
// SetExpiresAt sets the "expires_at" field.
func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne {
_u.mutation.SetExpiresAt(v)
return _u
}
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne {
if v != nil {
_u.SetExpiresAt(*v)
}
return _u
}
// ClearExpiresAt clears the value of the "expires_at" field.
func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne {
_u.mutation.ClearExpiresAt()
return _u
}
// SetGroupID sets the "group_id" field.
func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne {
_u.mutation.SetGroupID(v)
@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode,
if _u.mutation.NotesCleared() {
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
}
if value, ok := _u.mutation.ExpiresAt(); ok {
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
}
if _u.mutation.ExpiresAtCleared() {
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
}
if value, ok := _u.mutation.ValidityDays(); ok {
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
}

View File

@ -464,8 +464,14 @@ func init() {
return nil
}
}()
// channelmonitorDescAPIMode is the schema descriptor for api_mode field.
channelmonitorDescAPIMode := channelmonitorFields[2].Descriptor()
// channelmonitor.DefaultAPIMode holds the default value on creation for the api_mode field.
channelmonitor.DefaultAPIMode = channelmonitorDescAPIMode.Default.(string)
// channelmonitor.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
channelmonitor.APIModeValidator = channelmonitorDescAPIMode.Validators[0].(func(string) error)
// channelmonitorDescEndpoint is the schema descriptor for endpoint field.
channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor()
channelmonitorDescEndpoint := channelmonitorFields[3].Descriptor()
// channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
channelmonitor.EndpointValidator = func() func(string) error {
validators := channelmonitorDescEndpoint.Validators
@ -483,11 +489,11 @@ func init() {
}
}()
// channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field.
channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor()
channelmonitorDescAPIKeyEncrypted := channelmonitorFields[4].Descriptor()
// channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error)
// channelmonitorDescPrimaryModel is the schema descriptor for primary_model field.
channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor()
channelmonitorDescPrimaryModel := channelmonitorFields[5].Descriptor()
// channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
channelmonitor.PrimaryModelValidator = func() func(string) error {
validators := channelmonitorDescPrimaryModel.Validators
@ -505,29 +511,29 @@ func init() {
}
}()
// channelmonitorDescExtraModels is the schema descriptor for extra_models field.
channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor()
channelmonitorDescExtraModels := channelmonitorFields[6].Descriptor()
// channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field.
channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string)
// channelmonitorDescGroupName is the schema descriptor for group_name field.
channelmonitorDescGroupName := channelmonitorFields[6].Descriptor()
channelmonitorDescGroupName := channelmonitorFields[7].Descriptor()
// channelmonitor.DefaultGroupName holds the default value on creation for the group_name field.
channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string)
// channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error)
// channelmonitorDescEnabled is the schema descriptor for enabled field.
channelmonitorDescEnabled := channelmonitorFields[7].Descriptor()
channelmonitorDescEnabled := channelmonitorFields[8].Descriptor()
// channelmonitor.DefaultEnabled holds the default value on creation for the enabled field.
channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool)
// channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field.
channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
channelmonitorDescIntervalSeconds := channelmonitorFields[9].Descriptor()
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
// channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
channelmonitorDescExtraHeaders := channelmonitorFields[13].Descriptor()
// channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
// channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
channelmonitorDescBodyOverrideMode := channelmonitorFields[14].Descriptor()
// channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
// channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
@ -661,18 +667,24 @@ func init() {
return nil
}
}()
// channelmonitorrequesttemplateDescAPIMode is the schema descriptor for api_mode field.
channelmonitorrequesttemplateDescAPIMode := channelmonitorrequesttemplateFields[2].Descriptor()
// channelmonitorrequesttemplate.DefaultAPIMode holds the default value on creation for the api_mode field.
channelmonitorrequesttemplate.DefaultAPIMode = channelmonitorrequesttemplateDescAPIMode.Default.(string)
// channelmonitorrequesttemplate.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
channelmonitorrequesttemplate.APIModeValidator = channelmonitorrequesttemplateDescAPIMode.Validators[0].(func(string) error)
// channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[3].Descriptor()
// channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
// channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
// channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[4].Descriptor()
// channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
// channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[5].Descriptor()
// channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
// channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
@ -1386,7 +1398,7 @@ func init() {
// redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field.
redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time)
// redeemcodeDescValidityDays is the schema descriptor for validity_days field.
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor()
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
securitysecretMixin := schema.SecuritySecret{}.Mixin()
@ -1722,12 +1734,24 @@ func init() {
usagelogDescImageSize := usagelogFields[34].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescImageInputSize is the schema descriptor for image_input_size field.
usagelogDescImageInputSize := usagelogFields[35].Descriptor()
// usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error)
// usagelogDescImageOutputSize is the schema descriptor for image_output_size field.
usagelogDescImageOutputSize := usagelogFields[36].Descriptor()
// usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error)
// usagelogDescImageSizeSource is the schema descriptor for image_size_source field.
usagelogDescImageSizeSource := usagelogFields[37].Descriptor()
// usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error)
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor()
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
usagelogDescCreatedAt := usagelogFields[40].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()

View File

@ -15,12 +15,13 @@ import (
)
var authProviderTypes = map[string]struct{}{
"email": {},
"github": {},
"google": {},
"linuxdo": {},
"oidc": {},
"wechat": {},
"email": {},
"github": {},
"google": {},
"linuxdo": {},
"oidc": {},
"wechat": {},
"dingtalk": {},
}
func validateAuthProviderType(value string) error {

View File

@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
require.Equal(t, 1, signupSource.Validators)
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} {
require.NoError(t, validator(value))
}
require.Error(t, validator("unknown"))

View File

@ -36,6 +36,10 @@ func (ChannelMonitor) Fields() []ent.Field {
MaxLen(100),
field.Enum("provider").
Values("openai", "anthropic", "gemini"),
field.String("api_mode").
Default("chat_completions").
MaxLen(32).
Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
field.String("endpoint").
NotEmpty().
MaxLen(500).
@ -104,6 +108,7 @@ func (ChannelMonitor) Indexes() []ent.Index {
return []ent.Index{
index.Fields("enabled", "last_checked_at"),
index.Fields("provider"),
index.Fields("provider", "api_mode"),
index.Fields("group_name"),
index.Fields("template_id"),
}

View File

@ -40,6 +40,10 @@ func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
MaxLen(100),
field.Enum("provider").
Values("openai", "anthropic", "gemini"),
field.String("api_mode").
Default("chat_completions").
MaxLen(32).
Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
field.String("description").
Optional().
Default("").
@ -76,5 +80,6 @@ func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
return []ent.Index{
// 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
index.Fields("provider", "name").Unique(),
index.Fields("provider", "api_mode"),
}
}

View File

@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field {
Immutable().
Default(time.Now).
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Time("expires_at").
Optional().
Nillable().
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
field.Int64("group_id").
Optional().
Nillable(),
@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index {
index.Fields("status"),
index.Fields("used_by"),
index.Fields("group_id"),
index.Fields("expires_at"),
}
}

View File

@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field {
MaxLen(10).
Optional().
Nillable(),
field.String("image_input_size").
MaxLen(32).
Optional().
Nillable(),
field.String("image_output_size").
MaxLen(32).
Optional().
Nillable(),
field.String("image_size_source").
MaxLen(16).
Optional().
Nillable(),
field.JSON("image_size_breakdown", map[string]int{}).
Optional().
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
field.Bool("cache_ttl_overridden").
Default(false),

View File

@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
field.String("signup_source").
Validate(func(value string) error {
switch value {
case "email", "linuxdo", "wechat", "oidc", "github", "google":
case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
return nil
default:
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk")
}
}).
Default("email"),

View File

@ -3,6 +3,7 @@
package ent
import (
"encoding/json"
"fmt"
"strings"
"time"
@ -92,6 +93,14 @@ type UsageLog struct {
ImageCount int `json:"image_count,omitempty"`
// ImageSize holds the value of the "image_size" field.
ImageSize *string `json:"image_size,omitempty"`
// ImageInputSize holds the value of the "image_input_size" field.
ImageInputSize *string `json:"image_input_size,omitempty"`
// ImageOutputSize holds the value of the "image_output_size" field.
ImageOutputSize *string `json:"image_output_size,omitempty"`
// ImageSizeSource holds the value of the "image_size_source" field.
ImageSizeSource *string `json:"image_size_source,omitempty"`
// ImageSizeBreakdown holds the value of the "image_size_breakdown" field.
ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"`
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
// CreatedAt holds the value of the "created_at" field.
@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case usagelog.FieldImageSizeBreakdown:
values[i] = new([]byte)
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
values[i] = new(sql.NullBool)
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource:
values[i] = new(sql.NullString)
case usagelog.FieldCreatedAt:
values[i] = new(sql.NullTime)
@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
_m.ImageSize = new(string)
*_m.ImageSize = value.String
}
case usagelog.FieldImageInputSize:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_input_size", values[i])
} else if value.Valid {
_m.ImageInputSize = new(string)
*_m.ImageInputSize = value.String
}
case usagelog.FieldImageOutputSize:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_output_size", values[i])
} else if value.Valid {
_m.ImageOutputSize = new(string)
*_m.ImageOutputSize = value.String
}
case usagelog.FieldImageSizeSource:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field image_size_source", values[i])
} else if value.Valid {
_m.ImageSizeSource = new(string)
*_m.ImageSizeSource = value.String
}
case usagelog.FieldImageSizeBreakdown:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i])
} else if value != nil && len(*value) > 0 {
if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil {
return fmt.Errorf("unmarshal field image_size_breakdown: %w", err)
}
}
case usagelog.FieldCacheTTLOverridden:
if value, ok := values[i].(*sql.NullBool); !ok {
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
@ -640,6 +680,24 @@ func (_m *UsageLog) String() string {
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageInputSize; v != nil {
builder.WriteString("image_input_size=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageOutputSize; v != nil {
builder.WriteString("image_output_size=")
builder.WriteString(*v)
}
builder.WriteString(", ")
if v := _m.ImageSizeSource; v != nil {
builder.WriteString("image_size_source=")
builder.WriteString(*v)
}
builder.WriteString(", ")
builder.WriteString("image_size_breakdown=")
builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown))
builder.WriteString(", ")
builder.WriteString("cache_ttl_overridden=")
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
builder.WriteString(", ")

View File

@ -84,6 +84,14 @@ const (
FieldImageCount = "image_count"
// FieldImageSize holds the string denoting the image_size field in the database.
FieldImageSize = "image_size"
// FieldImageInputSize holds the string denoting the image_input_size field in the database.
FieldImageInputSize = "image_input_size"
// FieldImageOutputSize holds the string denoting the image_output_size field in the database.
FieldImageOutputSize = "image_output_size"
// FieldImageSizeSource holds the string denoting the image_size_source field in the database.
FieldImageSizeSource = "image_size_source"
// FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database.
FieldImageSizeBreakdown = "image_size_breakdown"
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
FieldCacheTTLOverridden = "cache_ttl_overridden"
// FieldCreatedAt holds the string denoting the created_at field in the database.
@ -175,6 +183,10 @@ var Columns = []string{
FieldIPAddress,
FieldImageCount,
FieldImageSize,
FieldImageInputSize,
FieldImageOutputSize,
FieldImageSizeSource,
FieldImageSizeBreakdown,
FieldCacheTTLOverridden,
FieldCreatedAt,
}
@ -242,6 +254,12 @@ var (
DefaultImageCount int
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
ImageSizeValidator func(string) error
// ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
ImageInputSizeValidator func(string) error
// ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
ImageOutputSizeValidator func(string) error
// ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
ImageSizeSourceValidator func(string) error
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
DefaultCacheTTLOverridden bool
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
}
// ByImageInputSize orders the results by the image_input_size field.
func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageInputSize, opts...).ToFunc()
}
// ByImageOutputSize orders the results by the image_output_size field.
func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc()
}
// ByImageSizeSource orders the results by the image_size_source field.
func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc()
}
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()

View File

@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
}
// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ.
func ImageInputSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
}
// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ.
func ImageOutputSize(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
}
// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ.
func ImageSizeSource(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
}
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
func CacheTTLOverridden(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
}
// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field.
func ImageInputSizeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
}
// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field.
func ImageInputSizeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v))
}
// ImageInputSizeIn applies the In predicate on the "image_input_size" field.
func ImageInputSizeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...))
}
// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field.
func ImageInputSizeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...))
}
// ImageInputSizeGT applies the GT predicate on the "image_input_size" field.
func ImageInputSizeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v))
}
// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field.
func ImageInputSizeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v))
}
// ImageInputSizeLT applies the LT predicate on the "image_input_size" field.
func ImageInputSizeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v))
}
// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field.
func ImageInputSizeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v))
}
// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field.
func ImageInputSizeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v))
}
// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field.
func ImageInputSizeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v))
}
// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field.
func ImageInputSizeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v))
}
// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field.
func ImageInputSizeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize))
}
// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field.
func ImageInputSizeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize))
}
// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field.
func ImageInputSizeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v))
}
// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field.
func ImageInputSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v))
}
// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field.
func ImageOutputSizeEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
}
// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field.
func ImageOutputSizeNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v))
}
// ImageOutputSizeIn applies the In predicate on the "image_output_size" field.
func ImageOutputSizeIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...))
}
// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field.
func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...))
}
// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field.
func ImageOutputSizeGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v))
}
// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field.
func ImageOutputSizeGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v))
}
// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field.
func ImageOutputSizeLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v))
}
// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field.
func ImageOutputSizeLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v))
}
// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field.
func ImageOutputSizeContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v))
}
// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field.
func ImageOutputSizeHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v))
}
// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field.
func ImageOutputSizeHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v))
}
// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field.
func ImageOutputSizeIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize))
}
// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field.
func ImageOutputSizeNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize))
}
// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field.
func ImageOutputSizeEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v))
}
// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field.
func ImageOutputSizeContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v))
}
// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field.
func ImageSizeSourceEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
}
// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field.
func ImageSizeSourceNEQ(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v))
}
// ImageSizeSourceIn applies the In predicate on the "image_size_source" field.
func ImageSizeSourceIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...))
}
// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field.
func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...))
}
// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field.
func ImageSizeSourceGT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v))
}
// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field.
func ImageSizeSourceGTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v))
}
// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field.
func ImageSizeSourceLT(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v))
}
// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field.
func ImageSizeSourceLTE(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v))
}
// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field.
func ImageSizeSourceContains(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v))
}
// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field.
func ImageSizeSourceHasPrefix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v))
}
// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field.
func ImageSizeSourceHasSuffix(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v))
}
// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field.
func ImageSizeSourceIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource))
}
// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field.
func ImageSizeSourceNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource))
}
// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field.
func ImageSizeSourceEqualFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v))
}
// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field.
func ImageSizeSourceContainsFold(v string) predicate.UsageLog {
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v))
}
// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field.
func ImageSizeBreakdownIsNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown))
}
// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field.
func ImageSizeBreakdownNotNil() predicate.UsageLog {
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown))
}
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))

View File

@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
return _c
}
// SetImageInputSize sets the "image_input_size" field.
func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate {
_c.mutation.SetImageInputSize(v)
return _c
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageInputSize(*v)
}
return _c
}
// SetImageOutputSize sets the "image_output_size" field.
func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate {
_c.mutation.SetImageOutputSize(v)
return _c
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageOutputSize(*v)
}
return _c
}
// SetImageSizeSource sets the "image_size_source" field.
func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate {
_c.mutation.SetImageSizeSource(v)
return _c
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate {
if v != nil {
_c.SetImageSizeSource(*v)
}
return _c
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate {
_c.mutation.SetImageSizeBreakdown(v)
return _c
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
_c.mutation.SetCacheTTLOverridden(v)
@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _c.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
}
@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
_node.ImageSize = &value
}
if value, ok := _c.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
_node.ImageInputSize = &value
}
if value, ok := _c.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
_node.ImageOutputSize = &value
}
if value, ok := _c.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
_node.ImageSizeSource = &value
}
if value, ok := _c.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
_node.ImageSizeBreakdown = value
}
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
_node.CacheTTLOverridden = value
@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
return u
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageInputSize, v)
return u
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageInputSize)
return u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageInputSize)
return u
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageOutputSize, v)
return u
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageOutputSize)
return u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageOutputSize)
return u
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert {
u.Set(usagelog.FieldImageSizeSource, v)
return u
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageSizeSource)
return u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageSizeSource)
return u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert {
u.Set(usagelog.FieldImageSizeBreakdown, v)
return u
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert {
u.SetExcluded(usagelog.FieldImageSizeBreakdown)
return u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert {
u.SetNull(usagelog.FieldImageSizeBreakdown)
return u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
u.Set(usagelog.FieldCacheTTLOverridden, v)
@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
})
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageInputSize(v)
})
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageInputSize()
})
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageInputSize()
})
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageOutputSize(v)
})
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageOutputSize()
})
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageOutputSize()
})
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeSource(v)
})
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeSource()
})
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeSource()
})
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeBreakdown(v)
})
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeBreakdown()
})
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeBreakdown()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
})
}
// SetImageInputSize sets the "image_input_size" field.
func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageInputSize(v)
})
}
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageInputSize()
})
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageInputSize()
})
}
// SetImageOutputSize sets the "image_output_size" field.
func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageOutputSize(v)
})
}
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageOutputSize()
})
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageOutputSize()
})
}
// SetImageSizeSource sets the "image_size_source" field.
func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeSource(v)
})
}
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeSource()
})
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeSource()
})
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.SetImageSizeBreakdown(v)
})
}
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.UpdateImageSizeBreakdown()
})
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
s.ClearImageSizeBreakdown()
})
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {

View File

@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
return _u
}
// SetImageInputSize sets the "image_input_size" field.
func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate {
_u.mutation.SetImageInputSize(v)
return _u
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageInputSize(*v)
}
return _u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate {
_u.mutation.ClearImageInputSize()
return _u
}
// SetImageOutputSize sets the "image_output_size" field.
func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate {
_u.mutation.SetImageOutputSize(v)
return _u
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageOutputSize(*v)
}
return _u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate {
_u.mutation.ClearImageOutputSize()
return _u
}
// SetImageSizeSource sets the "image_size_source" field.
func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate {
_u.mutation.SetImageSizeSource(v)
return _u
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate {
if v != nil {
_u.SetImageSizeSource(*v)
}
return _u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate {
_u.mutation.ClearImageSizeSource()
return _u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate {
_u.mutation.SetImageSizeBreakdown(v)
return _u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate {
_u.mutation.ClearImageSizeBreakdown()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
_u.mutation.SetCacheTTLOverridden(v)
@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
}
if _u.mutation.ImageInputSizeCleared() {
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
}
if _u.mutation.ImageOutputSizeCleared() {
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
}
if _u.mutation.ImageSizeSourceCleared() {
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
}
if _u.mutation.ImageSizeBreakdownCleared() {
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}
@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
return _u
}
// SetImageInputSize sets the "image_input_size" field.
func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne {
_u.mutation.SetImageInputSize(v)
return _u
}
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageInputSize(*v)
}
return _u
}
// ClearImageInputSize clears the value of the "image_input_size" field.
func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne {
_u.mutation.ClearImageInputSize()
return _u
}
// SetImageOutputSize sets the "image_output_size" field.
func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne {
_u.mutation.SetImageOutputSize(v)
return _u
}
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageOutputSize(*v)
}
return _u
}
// ClearImageOutputSize clears the value of the "image_output_size" field.
func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne {
_u.mutation.ClearImageOutputSize()
return _u
}
// SetImageSizeSource sets the "image_size_source" field.
func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne {
_u.mutation.SetImageSizeSource(v)
return _u
}
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne {
if v != nil {
_u.SetImageSizeSource(*v)
}
return _u
}
// ClearImageSizeSource clears the value of the "image_size_source" field.
func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne {
_u.mutation.ClearImageSizeSource()
return _u
}
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne {
_u.mutation.SetImageSizeBreakdown(v)
return _u
}
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne {
_u.mutation.ClearImageSizeBreakdown()
return _u
}
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
_u.mutation.SetCacheTTLOverridden(v)
@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error {
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageInputSize(); ok {
if err := usagelog.ImageInputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageOutputSize(); ok {
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
}
}
if v, ok := _u.mutation.ImageSizeSource(); ok {
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
}
}
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
}
@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if _u.mutation.ImageSizeCleared() {
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
}
if value, ok := _u.mutation.ImageInputSize(); ok {
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
}
if _u.mutation.ImageInputSizeCleared() {
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageOutputSize(); ok {
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
}
if _u.mutation.ImageOutputSizeCleared() {
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeSource(); ok {
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
}
if _u.mutation.ImageSizeSourceCleared() {
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
}
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
}
if _u.mutation.ImageSizeBreakdownCleared() {
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
}
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
}

View File

@ -222,6 +222,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -255,6 +257,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -284,6 +288,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
@ -316,6 +322,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=

View File

@ -79,6 +79,7 @@ type Config struct {
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"`
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
Default DefaultConfig `mapstructure:"default"`
@ -250,6 +251,47 @@ type OIDCConnectConfig struct {
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
type DingTalkConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
// 平台底座 + 业务行为
DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"V4 fail-closed
AppType string `mapstructure:"app_type"` // "public" (default) | "internal"
// Corp 限定none | internal_only
CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"`
InternalCorpID string `mapstructure:"internal_corp_id"`
BypassRegistration bool `mapstructure:"bypass_registration"`
SyncCorpEmail bool `mapstructure:"sync_corp_email"`
SyncDisplayName bool `mapstructure:"sync_display_name"`
SyncDept bool `mapstructure:"sync_dept"`
SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"`
SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"`
SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"`
SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"`
SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"`
SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"`
// 邮箱 + Username
RequireEmail bool `mapstructure:"require_email"`
UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"`
// Attribute私有版扩展点开源版仅声明
UsernameAttributeKey string `mapstructure:"username_attribute_key"`
EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"`
EnableAttributeSync bool `mapstructure:"enable_attribute_sync"`
AttributeSyncFields []string `mapstructure:"attribute_sync_fields"`
AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"`
}
type EmailOAuthProviderConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
@ -1639,6 +1681,19 @@ func setDefaults() {
viper.SetDefault("oidc_connect.userinfo_id_path", "")
viper.SetDefault("oidc_connect.userinfo_username_path", "")
// DingTalk Connect OAuth 登录
viper.SetDefault("dingtalk_connect.enabled", false)
viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth")
viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken")
viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me")
viper.SetDefault("dingtalk_connect.scopes", "openid")
viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback")
viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app")
viper.SetDefault("dingtalk_connect.app_type", "public")
viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none")
viper.SetDefault("dingtalk_connect.require_email", true)
viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty")
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
@ -2767,6 +2822,9 @@ func (c *Config) Validate() error {
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
if err := ValidateDingTalkConfig(c.DingTalk); err != nil {
return fmt.Errorf("dingtalk_connect: %w", err)
}
return nil
}

View File

@ -0,0 +1,30 @@
// Package config 包含钉钉连接配置的校验逻辑。
//
// internal_only 模式安全模型(方案 A
// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。
// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth
// 因此 ValidateDingTalkConfig 只要求 app_type=internalV1不再要求 InternalCorpID 非空(原 V3 已删除)。
// InternalCorpID 字段保留admin 可选填若填写checkDingTalkCorpAllowed 不会使用它做约束。
package config
import "errors"
var (
ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal")
ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app")
)
func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error {
if !cfg.Enabled {
return nil
}
if cfg.DingTalkAppKind != "internal_app" {
return ErrDingTalkV4InvalidAppKind
}
if cfg.CorpRestrictionPolicy == "internal_only" {
if cfg.AppType != "internal" {
return ErrDingTalkV1AppTypeMismatch
}
}
return nil
}

View File

@ -0,0 +1,53 @@
package config
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) {
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false}))
}
func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "third_party_enterprise_app",
CorpRestrictionPolicy: "none",
})
require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind)
}
func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "public",
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "dingABC",
})
require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch)
}
// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A
// internal_only 策略下InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。
func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
err := ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "internal",
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "",
})
require.NoError(t, err)
}
func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) {
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app",
AppType: "public",
CorpRestrictionPolicy: "none",
}))
}

View File

@ -44,6 +44,9 @@ type DataProxy struct {
Status string `json:"status"`
}
// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径,
// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本,
// 应新增独立结构而非修改这里。
type DataAccount struct {
Name string `json:"name"`
Notes *string `json:"notes,omitempty"`

View File

@ -1656,7 +1656,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
}
// GetUsage handles getting account usage information
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
// GET /api/v1/admin/accounts/:id/usage?source=passive|active&force=true
func (h *AccountHandler) GetUsage(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
@ -1665,12 +1665,13 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
}
source := c.DefaultQuery("source", "active")
force := c.Query("force") == "true"
var usage *service.UsageInfo
if source == "passive" {
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
} else {
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID, force)
}
if err != nil {
response.ErrorFrom(c, err)
@ -2022,6 +2023,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
response.Success(c, models)
}
// SyncUpstreamModels handles syncing live supported models from an account's upstream.
// POST /api/v1/admin/accounts/:id/models/sync-upstream
func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
if err != nil {
response.BadRequest(c, "Invalid account ID")
return
}
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
if err != nil {
response.NotFound(c, "Account not found")
return
}
if h.accountTestService == nil {
response.InternalError(c, "Account test service is not configured")
return
}
models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
if err != nil {
var syncErr *service.UpstreamModelSyncError
if errors.As(err, &syncErr) {
switch syncErr.Kind {
case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
response.BadRequest(c, syncErr.SafeMessage())
default:
slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
}
return
}
slog.Warn("sync_upstream_models_failed", "account_id", accountID)
response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
return
}
response.Success(c, gin.H{"models": models})
}
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
// POST /api/v1/admin/accounts/:id/set-privacy
func (h *AccountHandler) SetPrivacy(c *gin.Context) {

View File

@ -3,10 +3,14 @@ package admin
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
@ -33,6 +37,40 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
return router
}
type syncUpstreamHTTPUpstream struct {
resp *http.Response
err error
}
func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
if u.err != nil {
return nil, u.err
}
return u.resp, nil
}
func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
return u.Do(req, proxyURL, accountID, accountConcurrency)
}
func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
accountTestSvc := service.NewAccountTestService(
nil,
nil,
nil,
nil,
nil,
upstream,
&config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
nil,
)
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
return router
}
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
@ -103,3 +141,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
require.NotEmpty(t, resp.Data)
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
}
func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 44,
Name: "openai-apikey-missing-key",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"base_url": "https://openai.example.com/v1",
},
},
}
router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadRequest, rec.Code)
require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
}
func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
svc := &availableModelsAdminService{
stubAdminService: newStubAdminService(),
account: service.Account{
ID: 45,
Name: "openai-apikey-upstream-error",
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Credentials: map[string]any{
"api_key": "openai-key",
"base_url": "https://openai.example.com/v1",
},
},
}
upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
StatusCode: http.StatusBadGateway,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
}}
router := setupSyncUpstreamModelsRouter(svc, upstream)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
router.ServeHTTP(rec, req)
require.Equal(t, http.StatusBadGateway, rec.Code)
require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
}

View File

@ -17,11 +17,12 @@ import (
type ChannelHandler struct {
channelService *service.ChannelService
billingService *service.BillingService
pricingService *service.PricingService
}
// NewChannelHandler creates a new admin channel handler
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService}
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler {
return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService}
}
// --- Request / Response types ---
@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
"image_output_price": pricing.ImageOutputPricePerToken,
})
}
// platformToLiteLLMProvider maps a channel platform name to the corresponding
// LiteLLM provider string used as the key in the pricing catalog.
var platformToLiteLLMProvider = map[string]string{
service.PlatformAnthropic: "anthropic",
service.PlatformOpenAI: "openai",
service.PlatformGemini: "google",
service.PlatformAntigravity: "anthropic",
}
// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表
// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic
func (h *ChannelHandler) SyncPricingModels(c *gin.Context) {
platform := strings.ToLower(strings.TrimSpace(c.Query("platform")))
if platform == "" {
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required").
WithMetadata(map[string]string{"param": "platform"}))
return
}
provider, ok := platformToLiteLLMProvider[platform]
if !ok {
response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM",
fmt.Sprintf("unsupported platform: %s", platform)).
WithMetadata(map[string]string{"param": "platform"}))
return
}
models := h.pricingService.ListModelNamesByProvider(provider)
response.Success(c, gin.H{"models": models})
}

View File

@ -3,10 +3,14 @@
package admin
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. SyncPricingModels handler
// ---------------------------------------------------------------------------
func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
h := &ChannelHandler{pricingService: pricingSvc}
router.GET("/channels/pricing/sync-models", h.SyncPricingModels)
return router
}
func TestSyncPricingModels_MissingPlatform(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusBadRequest, w.Code)
}
func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) {
svc := service.NewPricingService(nil, nil)
router := setupSyncPricingModelsRouter(svc)
for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} {
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform)
var body struct {
Data struct {
Models []string `json:"models"`
} `json:"data"`
}
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform)
}
}

View File

@ -38,6 +38,7 @@ func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *Ch
type channelMonitorCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Endpoint string `json:"endpoint" binding:"required,max=500"`
APIKey string `json:"api_key" binding:"required,max=2000"`
PrimaryModel string `json:"primary_model" binding:"required,max=200"`
@ -54,6 +55,7 @@ type channelMonitorCreateRequest struct {
type channelMonitorUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
@ -72,6 +74,7 @@ type channelMonitorResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
APIMode string `json:"api_mode"`
Endpoint string `json:"endpoint"`
APIKeyMasked string `json:"api_key_masked"`
APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
@ -138,6 +141,7 @@ func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse
ID: m.ID,
Name: m.Name,
Provider: m.Provider,
APIMode: m.APIMode,
Endpoint: m.Endpoint,
APIKeyMasked: maskAPIKey(m.APIKey),
APIKeyDecryptFailed: m.APIKeyDecryptFailed,
@ -303,6 +307,7 @@ func (h *ChannelMonitorHandler) Create(c *gin.Context) {
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
Name: req.Name,
Provider: req.Provider,
APIMode: req.APIMode,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,
@ -338,6 +343,7 @@ func (h *ChannelMonitorHandler) Update(c *gin.Context) {
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
Name: req.Name,
Provider: req.Provider,
APIMode: req.APIMode,
Endpoint: req.Endpoint,
APIKey: req.APIKey,
PrimaryModel: req.PrimaryModel,

View File

@ -27,6 +27,7 @@ func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMon
type channelMonitorTemplateCreateRequest struct {
Name string `json:"name" binding:"required,max=100"`
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Description string `json:"description" binding:"max=500"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
@ -35,6 +36,7 @@ type channelMonitorTemplateCreateRequest struct {
type channelMonitorTemplateUpdateRequest struct {
Name *string `json:"name" binding:"omitempty,max=100"`
APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
Description *string `json:"description" binding:"omitempty,max=500"`
ExtraHeaders *map[string]string `json:"extra_headers"`
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
@ -45,6 +47,7 @@ type channelMonitorTemplateResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
APIMode string `json:"api_mode"`
Description string `json:"description"`
ExtraHeaders map[string]string `json:"extra_headers"`
BodyOverrideMode string `json:"body_override_mode"`
@ -67,6 +70,7 @@ func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *ser
ID: t.ID,
Name: t.Name,
Provider: t.Provider,
APIMode: t.APIMode,
Description: t.Description,
ExtraHeaders: headers,
BodyOverrideMode: t.BodyOverrideMode,
@ -93,6 +97,7 @@ func parseTemplateID(c *gin.Context) (int64, bool) {
func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
Provider: strings.TrimSpace(c.Query("provider")),
APIMode: strings.TrimSpace(c.Query("api_mode")),
})
if err != nil {
response.ErrorFrom(c, err)
@ -129,6 +134,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
Name: req.Name,
Provider: req.Provider,
APIMode: req.APIMode,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
@ -154,6 +160,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
}
t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
Name: req.Name,
APIMode: req.APIMode,
Description: req.Description,
ExtraHeaders: req.ExtraHeaders,
BodyOverrideMode: req.BodyOverrideMode,
@ -209,6 +216,7 @@ type associatedMonitorBriefResponse struct {
ID int64 `json:"id"`
Name string `json:"name"`
Provider string `json:"provider"`
APIMode string `json:"api_mode"`
Enabled bool `json:"enabled"`
}
@ -227,7 +235,7 @@ func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context
out := make([]associatedMonitorBriefResponse, 0, len(items))
for _, m := range items {
out = append(out, associatedMonitorBriefResponse{
ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
ID: m.ID, Name: m.Name, Provider: m.Provider, APIMode: m.APIMode, Enabled: m.Enabled,
})
}
response.Success(c, gin.H{"items": out})

View File

@ -46,6 +46,8 @@ type contentModerationConfigRequest struct {
HitRetentionDays *int `json:"hit_retention_days"`
NonHitRetentionDays *int `json:"non_hit_retention_days"`
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
BlockedKeywords *[]string `json:"blocked_keywords"`
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
}
type contentModerationAPIKeyTestRequest struct {
@ -103,6 +105,8 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
HitRetentionDays: req.HitRetentionDays,
NonHitRetentionDays: req.NonHitRetentionDays,
PreHashCheckEnabled: req.PreHashCheckEnabled,
BlockedKeywords: req.BlockedKeywords,
KeywordBlockingMode: req.KeywordBlockingMode,
})
if err != nil {
response.ErrorFrom(c, err)

View File

@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
return
}
// cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。
keyRaw, _ := json.Marshal(struct {
V int `json:"v"`
Day string `json:"day"`
UserIDs []int64 `json:"user_ids"`
}{
V: 2, // bump 当响应结构变化(如加入 by_platform 时)
Day: timezone.Today().Format("2006-01-02"),
UserIDs: userIDs,
})
cacheKey := string(keyRaw)

View File

@ -1,9 +1,7 @@
package admin
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
@ -386,79 +384,6 @@ func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
// RetryRequestErrorClient retries the client request based on stored request body.
// POST /api/v1/admin/ops/request-errors/:id/retry-client
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
idxStr := strings.TrimSpace(c.Param("idx"))
idx, err := strconv.Atoi(idxStr)
if err != nil || idx < 0 {
response.BadRequest(c, "Invalid upstream idx")
return
}
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveRequestError toggles resolved status.
// PUT /api/v1/admin/ops/request-errors/:id/resolve
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
@ -566,39 +491,6 @@ func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
h.GetErrorLogByID(c)
}
// RetryUpstreamError retries upstream error using the original account_id.
// POST /api/v1/admin/ops/upstream-errors/:id/retry
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ResolveUpstreamError toggles resolved status.
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
@ -708,106 +600,10 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
}
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
Force bool `json:"force"`
}
type opsResolveRequest struct {
Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
// POST /api/v1/admin/ops/errors/:id/retry
func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
subject, ok := middleware.GetAuthSubjectFromContext(c)
if !ok || subject.UserID <= 0 {
response.Error(c, http.StatusUnauthorized, "Unauthorized")
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
req := opsRetryRequest{Mode: service.OpsRetryModeClient}
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
if strings.TrimSpace(req.Mode) == "" {
req.Mode = service.OpsRetryModeClient
}
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
_ = req.Force
// Legacy endpoint safety: only allow retrying the client request here.
// Upstream retries must go through the split endpoints.
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
response.BadRequest(c, "upstream retry is not supported on this endpoint")
return
}
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, result)
}
// ListRetryAttempts lists retry attempts for an error log.
// GET /api/v1/admin/ops/errors/:id/retries
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
}
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
response.ErrorFrom(c, err)
return
}
idStr := strings.TrimSpace(c.Param("id"))
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil || id <= 0 {
response.BadRequest(c, "Invalid error id")
return
}
limit := 50
if v := strings.TrimSpace(c.Query("limit")); v != "" {
n, err := strconv.Atoi(v)
if err != nil || n <= 0 {
response.BadRequest(c, "Invalid limit")
return
}
limit = n
}
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, items)
}
// UpdateErrorResolution allows manual resolve/unresolve.
// PUT /api/v1/admin/ops/errors/:id/resolve
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
@ -839,7 +635,7 @@ func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
return
}
uid := subject.UserID
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid); err != nil {
response.ErrorFrom(c, err)
return
}

View File

@ -8,6 +8,7 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
// GenerateRedeemCodesRequest represents generate redeem codes request
type GenerateRedeemCodesRequest struct {
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
Count int `json:"count" binding:"required,min=1,max=100"`
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
Value float64 `json:"value"`
GroupID *int64 `json:"group_id"` // 订阅类型必填
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
ExpiresAt *time.Time `json:"expires_at"`
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
}
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
Notes string `json:"notes"`
Code string `json:"code" binding:"required,min=3,max=128"`
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance向后兼容
Value float64 `json:"value" binding:"required"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
GroupID *int64 `json:"group_id"` // subscription 类型必填
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
Notes string `json:"notes"`
ExpiresAt *time.Time `json:"expires_at"`
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
}
func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) {
if expiresAt != nil && expiresInDays != nil {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set")
}
now := time.Now().UTC()
if expiresInDays != nil {
if *expiresInDays <= 0 {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero")
}
expires := now.AddDate(0, 0, *expiresInDays)
return &expires, nil
}
if expiresAt == nil {
return nil, nil
}
expires := expiresAt.UTC()
if !expires.After(now) {
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
}
return &expires, nil
}
// List handles listing all redeem codes with pagination
@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
return
}
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
Count: req.Count,
@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
Value: req.Value,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
ExpiresAt: expiresAt,
})
if execErr != nil {
return nil, execErr
@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
}
}
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
if err != nil {
response.ErrorFrom(c, err)
return
}
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
if err == nil {
@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
Notes: req.Notes,
GroupID: req.GroupID,
ValidityDays: req.ValidityDays,
ExpiresAt: expiresAt,
})
if createErr != nil {
// Unique code race: if code now exists, use idempotent semantics by used_by.
@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis
}
// If previous run created the code but crashed before redeem, redeem it now.
if existing.IsExpired() {
return nil, service.ErrRedeemCodeExpired
}
if existing.CanUse() {
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
if err == nil {
@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
writer := csv.NewWriter(&buf)
// Write header
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
return
}
@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
if code.UsedAt != nil {
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
}
expiresAt := ""
if code.ExpiresAt != nil {
expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05")
}
if err := writer.Write([]string{
fmt.Sprintf("%d", code.ID),
code.Code,
@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
usedBy,
usedByEmail,
usedAt,
expiresAt,
code.CreatedAt.Format("2006-01-02 15:04:05"),
}); err != nil {
response.InternalError(c, "Failed to export redeem codes: "+err.Error())

View File

@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
assert.NotEqual(t, http.StatusBadRequest, code,
"balance type should not require group_id or validity_days")
}
func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) {
days := 3
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
require.NoError(t, err)
require.NotNil(t, expiresAt)
require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second)
}
func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) {
past := time.Now().UTC().Add(-time.Minute)
expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil)
require.Error(t, err)
require.Nil(t, expiresAt)
}
func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) {
days := 0
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
require.Error(t, err)
require.Nil(t, expiresAt)
}
func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) {
future := time.Now().UTC().Add(time.Hour)
days := 3
expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days)
require.Error(t, err)
require.Nil(t, expiresAt)
}

View File

@ -1,9 +1,11 @@
package admin
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
@ -60,10 +62,11 @@ type SettingHandler struct {
opsService *service.OpsService
paymentConfigService *service.PaymentConfigService
paymentService *service.PaymentService
userAttributeService *service.UserAttributeService
}
// NewSettingHandler 创建系统设置处理器
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler {
return &SettingHandler{
settingService: settingService,
emailService: emailService,
@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
opsService: opsService,
paymentConfigService: paymentConfigService,
paymentService: paymentService,
userAttributeService: userAttributeService,
}
}
@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: settings.DingTalkConnectEnabled,
DingTalkConnectClientID: settings.DingTalkConnectClientID,
DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured,
DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: settings.WeChatConnectEnabled,
WeChatConnectAppID: settings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
@ -258,6 +278,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
PaymentAlipayForceQRCode: paymentCfg.AlipayForceQRCode,
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
@ -376,6 +397,24 @@ type UpdateSettingsRequest struct {
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// DingTalk Connect OAuth 登录
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"`
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
// WeChat Connect OAuth 登录
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
@ -446,45 +485,50 @@ type UpdateSettingsRequest struct {
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
// 默认配置
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
DefaultConcurrency int `json:"default_concurrency"`
DefaultBalance float64 `json:"default_balance"`
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"`
AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"`
AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"`
AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"`
AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"`
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@ -560,6 +604,9 @@ type UpdateSettingsRequest struct {
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
PaymentAlipayForceQRCode *bool `json:"payment_alipay_force_qrcode"`
// Channel Monitor feature switch
ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
@ -661,6 +708,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions)
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
@ -777,6 +825,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
// DingTalk Connect 参数验证
// 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none
// 避免任何直连 admin API 的客户端把死值写回 DB前端 UI 已无此选项)。
req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy)
if req.DingTalkConnectEnabled {
req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID)
req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret)
req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL)
req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy)
req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID)
if req.DingTalkConnectClientID == "" {
response.BadRequest(c, "DingTalk Client ID is required when enabled")
return
}
if req.DingTalkConnectRedirectURL == "" {
response.BadRequest(c, "DingTalk Redirect URL is required when enabled")
return
}
if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil {
response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL")
return
}
// 如果未提供 client_secret则保留现有值如有
if req.DingTalkConnectClientSecret == "" {
if previousSettings.DingTalkConnectClientSecret == "" {
response.BadRequest(c, "DingTalk Client Secret is required when enabled")
return
}
req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret
}
// Corp 策略校验V1/V4 fail-closed
dingTalkCfg := config.DingTalkConnectConfig{
Enabled: true,
DingTalkAppKind: "internal_app", // 硬编码settings 层仅支持 internal_app
AppType: "internal", // 对于 internal_only 策略的默认值
CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
InternalCorpID: req.DingTalkConnectInternalCorpID,
}
// 若未填 corp_restriction_policy保留已有配置
if dingTalkCfg.CorpRestrictionPolicy == "" {
dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy
}
// 对于 internal_only 策略app_type 必须为 internalV1 校验)
if dingTalkCfg.CorpRestrictionPolicy == "internal_only" {
dingTalkCfg.AppType = "internal"
} else {
dingTalkCfg.AppType = "public"
}
if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil {
response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil)
return
}
// bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false
// 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。
if dingTalkCfg.CorpRestrictionPolicy != "internal_only" {
req.DingTalkConnectBypassRegistration = false
// 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。
req.DingTalkConnectSyncCorpEmail = false
req.DingTalkConnectSyncDisplayName = false
req.DingTalkConnectSyncDept = false
}
// 身份同步目标 attr keytrimSpace + 空值 fallback 到默认值
req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey)
if req.DingTalkConnectSyncCorpEmailAttrKey == "" {
req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
}
req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey)
if req.DingTalkConnectSyncDisplayNameAttrKey == "" {
req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
}
req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey)
if req.DingTalkConnectSyncDeptAttrKey == "" {
req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
}
// 身份同步目标 attr 显示名称trim + 空值 fallback 到默认中文名
req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName)
if req.DingTalkConnectSyncCorpEmailAttrName == "" {
req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
}
req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName)
if req.DingTalkConnectSyncDisplayNameAttrName == "" {
req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
}
req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName)
if req.DingTalkConnectSyncDeptAttrName == "" {
req.DingTalkConnectSyncDeptAttrName = "钉钉部门"
}
}
if req.WeChatConnectEnabled {
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
@ -1272,113 +1414,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
WeChatConnectEnabled: req.WeChatConnectEnabled,
WeChatConnectAppID: req.WeChatConnectAppID,
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
WeChatConnectMode: req.WeChatConnectMode,
WeChatConnectScopes: req.WeChatConnectScopes,
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
RegistrationEnabled: req.RegistrationEnabled,
EmailVerifyEnabled: req.EmailVerifyEnabled,
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
PromoCodeEnabled: req.PromoCodeEnabled,
PasswordResetEnabled: req.PasswordResetEnabled,
FrontendURL: req.FrontendURL,
InvitationCodeEnabled: req.InvitationCodeEnabled,
TotpEnabled: req.TotpEnabled,
LoginAgreementEnabled: req.LoginAgreementEnabled,
LoginAgreementMode: loginAgreementMode,
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
LoginAgreementDocuments: loginAgreementDocuments,
SMTPHost: req.SMTPHost,
SMTPPort: req.SMTPPort,
SMTPUsername: req.SMTPUsername,
SMTPPassword: req.SMTPPassword,
SMTPFrom: req.SMTPFrom,
SMTPFromName: req.SMTPFromName,
SMTPUseTLS: req.SMTPUseTLS,
TurnstileEnabled: req.TurnstileEnabled,
TurnstileSiteKey: req.TurnstileSiteKey,
TurnstileSecretKey: req.TurnstileSecretKey,
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: req.DingTalkConnectEnabled,
DingTalkConnectClientID: req.DingTalkConnectClientID,
DingTalkConnectClientSecret: req.DingTalkConnectClientSecret,
DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: req.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: req.WeChatConnectEnabled,
WeChatConnectAppID: req.WeChatConnectAppID,
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
WeChatConnectMode: req.WeChatConnectMode,
WeChatConnectScopes: req.WeChatConnectScopes,
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
OIDCConnectEnabled: req.OIDCConnectEnabled,
OIDCConnectProviderName: req.OIDCConnectProviderName,
OIDCConnectClientID: req.OIDCConnectClientID,
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
OIDCConnectScopes: req.OIDCConnectScopes,
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
OIDCConnectUsePKCE: oidcUsePKCE,
OIDCConnectValidateIDToken: oidcValidateIDToken,
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
GitHubOAuthClientID: req.GitHubOAuthClientID,
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
GoogleOAuthClientID: req.GoogleOAuthClientID,
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
SiteName: req.SiteName,
SiteLogo: req.SiteLogo,
SiteSubtitle: req.SiteSubtitle,
APIBaseURL: req.APIBaseURL,
ContactInfo: req.ContactInfo,
DocURL: req.DocURL,
HomeContent: req.HomeContent,
HideCcsImportButton: req.HideCcsImportButton,
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
TableDefaultPageSize: req.TableDefaultPageSize,
TablePageSizeOptions: req.TablePageSizeOptions,
CustomMenuItems: customMenuJSON,
CustomEndpoints: customEndpointsJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
AffiliateRebateRate: affiliateRebateRate,
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
AffiliateRebateDurationDays: affiliateRebateDurationDays,
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
FallbackModelGemini: req.FallbackModelGemini,
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
BackendModeEnabled: req.BackendModeEnabled,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@ -1574,6 +1732,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
},
DingTalk: service.ProviderDefaultGrantSettings{
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency),
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
},
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
}
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
@ -1613,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
AlipayForceQRCode: req.PaymentAlipayForceQRCode,
}
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
response.ErrorFrom(c, err)
@ -1632,6 +1798,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings)
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
@ -1682,6 +1849,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled,
DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID,
DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured,
DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL,
DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy,
DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID,
DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration,
DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail,
DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName,
DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept,
DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey,
DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey,
DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey,
DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName,
DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName,
DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName,
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
@ -1803,6 +1986,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
PaymentAlipayForceQRCode: updatedPaymentCfg.AlipayForceQRCode,
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
@ -1822,6 +2006,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
// hasPaymentFields returns true if any payment-related field was explicitly provided.
// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes.
func mapDingTalkValidateError(err error) string {
switch {
case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch):
return "dingtalk_apptype_mismatch"
case errors.Is(err, config.ErrDingTalkV4InvalidAppKind):
return "dingtalk_app_kind_invalid"
default:
return "dingtalk_corp_config_invalid"
}
}
func hasPaymentFields(req UpdateSettingsRequest) bool {
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
@ -1832,7 +2028,8 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
req.PaymentCancelRateLimitMax != nil || req.PaymentCancelRateLimitWindow != nil ||
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil ||
req.PaymentAlipayForceQRCode != nil
}
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
@ -1935,6 +2132,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
changed = append(changed, "linuxdo_connect_redirect_url")
}
if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled {
changed = append(changed, "dingtalk_connect_enabled")
}
if before.DingTalkConnectClientID != after.DingTalkConnectClientID {
changed = append(changed, "dingtalk_connect_client_id")
}
if req.DingTalkConnectClientSecret != "" {
changed = append(changed, "dingtalk_connect_client_secret")
}
if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL {
changed = append(changed, "dingtalk_connect_redirect_url")
}
if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy {
changed = append(changed, "dingtalk_connect_corp_restriction_policy")
}
if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID {
changed = append(changed, "dingtalk_connect_internal_corp_id")
}
if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration {
changed = append(changed, "dingtalk_connect_bypass_registration")
}
if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail {
changed = append(changed, "dingtalk_connect_sync_corp_email")
}
if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName {
changed = append(changed, "dingtalk_connect_sync_display_name")
}
if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept {
changed = append(changed, "dingtalk_connect_sync_dept")
}
if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey {
changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key")
}
if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey {
changed = append(changed, "dingtalk_connect_sync_display_name_attr_key")
}
if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey {
changed = append(changed, "dingtalk_connect_sync_dept_attr_key")
}
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
changed = append(changed, "wechat_connect_enabled")
}
@ -2246,6 +2482,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
{name: "wechat", before: before.WeChat, after: after.WeChat},
{name: "github", before: before.GitHub, after: after.GitHub},
{name: "google", before: before.Google, after: after.Google},
{name: "dingtalk", before: before.DingTalk, after: after.DingTalk},
}
for _, field := range fields {
if field.before.Balance != field.after.Balance {
@ -2350,6 +2587,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance
data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency
data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions
data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup
data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
@ -3044,3 +3286,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
}
response.Success(c, result)
}
// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name)
// 兜底 upsert 对应 user attribute definition不存在则创建存在但 name 不同则更新 name
// type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。
// 失败仅记录日志,不阻塞 settings 保存。
func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) {
if h.userAttributeService == nil || settings == nil {
return
}
if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
return
}
if settings.DingTalkConnectSyncDisplayName {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText)
}
if settings.DingTalkConnectSyncCorpEmail {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail)
}
if settings.DingTalkConnectSyncDept {
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText)
}
}
func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) {
key = strings.TrimSpace(key)
if key == "" {
return
}
existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
if err == nil && existing != nil {
if strings.TrimSpace(name) != "" && existing.Name != name {
if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{
Name: &name,
}); err != nil {
slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error())
return
}
slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name)
}
return
}
if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{
Key: key,
Name: name,
Description: description,
Type: attrType,
Enabled: true,
}); err != nil {
slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error())
return
}
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
}

View File

@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,
@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": false,
@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul
ClockSkewSeconds: 120,
},
})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
},
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"promo_code_enabled": true,
@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu
err: errors.New("write auth source defaults failed"),
}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
body := map[string]any{
"registration_enabled": true,

View File

@ -0,0 +1,319 @@
package admin
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub已在 setting_handler_auth_source_defaults_test.go 定义)
func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) {
repo := &settingHandlerRepoStub{values: map[string]string{}}
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
return handler, repo
}
// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。
func baseValidDingTalkBody() map[string]any {
return map[string]any{
"dingtalk_connect_enabled": true,
"dingtalk_connect_client_id": "test-client-id",
"dingtalk_connect_client_secret": "test-client-secret",
"dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback",
"dingtalk_connect_corp_restriction_policy": "none",
}
}
// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A
// internal_only + internal_corp_id="" 应通过校验(→ 200不再是 400。
func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200
func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "none"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_enabled"])
}
// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200
func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_internal_corp_id"] = "ding-corp-123"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。
// 必须用 policy=internal_onlybypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。
func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_bypass_registration"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_bypass_registration"])
}
// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。
// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为
// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。
func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := map[string]any{
"dingtalk_connect_enabled": false,
"dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝
"dingtalk_connect_corp_restriction_policy": "internal_only",
}
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
}
// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。
func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only")
require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only")
require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only")
}
// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。
func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, _ := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "none"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none")
require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none")
require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none")
}
// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容:
// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中)
// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。
func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) {
gin.SetMode(gin.TestMode)
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "whitelist"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy],
"stale whitelist 应在写入路径被 coerce 为 none")
}
// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。
func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Run("custom_attr_keys_saved", func(t *testing.T) {
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
body["dingtalk_connect_sync_corp_email"] = true
body["dingtalk_connect_sync_display_name"] = true
body["dingtalk_connect_sync_dept"] = true
body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr"
body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr"
body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr"
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 验证写入 DB 的 key
require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
// 验证响应中的 attr key
var resp response.Response
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
data, ok := resp.Data.(map[string]any)
require.True(t, ok)
require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"])
require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"])
require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"])
})
t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) {
handler, repo := newDingTalkSettingsHandler()
body := baseValidDingTalkBody()
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
// 不传 attr key → 写入层 fallback 到默认值
body["dingtalk_connect_sync_corp_email_attr_key"] = ""
body["dingtalk_connect_sync_display_name_attr_key"] = ""
body["dingtalk_connect_sync_dept_attr_key"] = ""
rawBody, err := json.Marshal(body)
require.NoError(t, err)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
c.Request.Header.Set("Content-Type", "application/json")
handler.UpdateSettings(c)
require.Equal(t, http.StatusOK, rec.Code)
// 空值应 fallback 到默认值并持久化
require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
})
}

View File

@ -0,0 +1,398 @@
package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集
type dingTalkClientConfig struct {
ClientID string
ClientSecret string
TokenURL string
UserInfoURL string
}
type DingTalkClient struct {
cfg dingTalkClientConfig
appToken string
appTokenExp time.Time // 钉钉 7200s留 200s 余量 → 7000s
mu sync.Mutex
httpClient *http.Client
// TODO(multi-instance): Redis 集中缓存 appToken
}
type DingTalkUserTokenResp struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpireIn int64 `json:"expireIn"`
CorpID string `json:"corpId"`
}
func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) {
body := map[string]string{
"clientId": c.cfg.ClientID,
"clientSecret": c.cfg.ClientSecret,
"code": code,
"grantType": "authorization_code",
}
payload, _ := json.Marshal(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var out DingTalkUserTokenResp
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
if strings.TrimSpace(out.AccessToken) == "" {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
return &out, nil
}
type DingTalkAPIError struct {
Code string
Message string
HTTP int
}
func (e *DingTalkAPIError) Error() string {
return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP)
}
func parseDingTalkErr(raw []byte, status int) error {
var v struct {
Code string `json:"code"`
Message string `json:"message"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
_ = json.Unmarshal(raw, &v)
code := v.Code
if code == "" && v.ErrCode != 0 {
code = fmt.Sprintf("%d", v.ErrCode)
}
msg := v.Message
if msg == "" {
msg = v.ErrMsg
}
return &DingTalkAPIError{Code: code, Message: msg, HTTP: status}
}
// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。
// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。
func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil)
if err != nil {
return "", "", err
}
req.Header.Set("x-acs-dingtalk-access-token", userToken)
resp, err := c.httpClient.Do(req)
if err != nil {
return "", "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
UnionID string `json:"unionId"`
Nick string `json:"nick"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", "", err
}
if strings.TrimSpace(v.UnionID) == "" {
return "", "", parseDingTalkErr(raw, resp.StatusCode)
}
return v.UnionID, v.Nick, nil
}
type DingTalkStaffInfo struct {
UserID string
Name string // 企业内真实姓名(钉钉企业管理后台配置)
Nickname string // 钉钉个人昵称(用户自己设置)
Email string
DeptIDs []int64
// CorpID 不来自 staff 接口,来自 userToken不在此 struct
}
// dingTalkOAPIBase 推导钉钉旧版 OAPI base URLhost: api.dingtalk.com → oapi.dingtalk.com
// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。
func (c *DingTalkClient) dingTalkOAPIBase() string {
u, err := url.Parse(c.cfg.UserInfoURL)
if err != nil || u.Scheme == "" || u.Host == "" {
return "https://oapi.dingtalk.com"
}
host := u.Host
if strings.HasPrefix(host, "api.") {
host = "oapi." + strings.TrimPrefix(host, "api.")
}
return u.Scheme + "://" + host
}
func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.appToken != "" && time.Now().Before(c.appTokenExp) {
return c.appToken, nil
}
body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret}
payload, _ := json.Marshal(body)
// 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken
// 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明)
appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1)
if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") {
appTokenURL = c.cfg.TokenURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
AccessToken string `json:"accessToken"`
ExpireIn int64 `json:"expireIn"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", err
}
if v.AccessToken == "" {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
c.appToken = v.AccessToken
ttl := v.ExpireIn
if ttl > 200 {
ttl -= 200
}
c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second)
return c.appToken, nil
}
func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return "", err
}
body := map[string]string{"unionid": unionID}
payload, _ := json.Marshal(body)
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX
// access_token 通过 query string 传递(不是 header
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
UserID string `json:"userid"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return "", err
}
if v.ErrCode != 0 {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
if strings.TrimSpace(v.Result.UserID) == "" {
return "", parseDingTalkErr(raw, resp.StatusCode)
}
return v.Result.UserID, nil
}
// DingTalkDeptInfo 部门信息topapi/v2/department/get 返回子集)
type DingTalkDeptInfo struct {
DeptID int64
Name string
ParentID int64
}
// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。
// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX
func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return nil, err
}
body := map[string]any{"dept_id": deptID, "language": "zh_CN"}
payload, _ := json.Marshal(body)
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // test stub fallback
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
DeptID int64 `json:"dept_id"`
Name string `json:"name"`
ParentID int64 `json:"parent_id"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, err
}
if v.ErrCode != 0 {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
return &DingTalkDeptInfo{
DeptID: v.Result.DeptID,
Name: v.Result.Name,
ParentID: v.Result.ParentID,
}, nil
}
func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) {
appToken, err := c.GetAppToken(ctx)
if err != nil {
return nil, err
}
body := map[string]string{"userid": userID}
payload, _ := json.Marshal(body)
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX
var targetURL string
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken)
} else {
targetURL = c.cfg.UserInfoURL // fallback for test stub
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
raw, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
var v struct {
Result struct {
UserID string `json:"userid"`
Name string `json:"name"`
Nickname string `json:"nickname"`
Email string `json:"email"`
OrgEmail string `json:"org_email"`
Extension string `json:"extension"`
DeptID []int64 `json:"dept_id_list"`
} `json:"result"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(raw, &v); err != nil {
return nil, err
}
if v.ErrCode != 0 {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
if strings.TrimSpace(v.Result.UserID) == "" {
return nil, parseDingTalkErr(raw, resp.StatusCode)
}
// 邮箱三级 fallbackorg_email > email > extension["企业邮箱"]钉钉自定义扩展字段JSON string
email := strings.TrimSpace(v.Result.OrgEmail)
emailSource := "org_email"
if email == "" {
email = strings.TrimSpace(v.Result.Email)
emailSource = "email"
}
extensionParsed := false
if email == "" && strings.TrimSpace(v.Result.Extension) != "" {
var ext map[string]string
if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil {
extensionParsed = true
if v, ok := ext["企业邮箱"]; ok {
email = strings.TrimSpace(v)
emailSource = "extension.企业邮箱"
}
}
}
if email == "" {
emailSource = "none"
}
slog.Info("dingtalk staff fetched",
"userid", v.Result.UserID,
"name_present", v.Result.Name != "",
"nickname_present", v.Result.Nickname != "",
"name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname,
"email_present", v.Result.Email != "",
"org_email_present", v.Result.OrgEmail != "",
"extension_present", v.Result.Extension != "",
"extension_parsed", extensionParsed,
"email_source", emailSource,
"dept_count", len(v.Result.DeptID),
)
return &DingTalkStaffInfo{
UserID: v.Result.UserID,
Name: v.Result.Name,
Nickname: v.Result.Nickname,
Email: email,
DeptIDs: v.Result.DeptID,
}, nil
}

View File

@ -0,0 +1,143 @@
package handler
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "POST", r.Method)
require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{
ClientID: "k", ClientSecret: "s",
TokenURL: server.URL + "/v1.0/oauth2/userAccessToken",
},
httpClient: server.Client(),
}
resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE")
require.NoError(t, err)
require.Equal(t, "USER_TOKEN_X", resp.AccessToken)
require.Equal(t, "dingABC", resp.CorpID)
}
func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token"))
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"},
httpClient: server.Client(),
}
unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X")
require.NoError(t, err)
require.Equal(t, "UID_AAA", unionID)
require.Equal(t, "张三", nick)
}
func TestDingTalkClient_GetAppToken_Cached(t *testing.T) {
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"},
httpClient: server.Client(),
}
t1, err := cli.GetAppToken(context.Background())
require.NoError(t, err)
t2, err := cli.GetAppToken(context.Background())
require.NoError(t, err)
require.Equal(t, t1, t2)
require.Equal(t, 1, callCount, "second call should hit cache")
}
func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) {
appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
}))
defer appTokenServer.Close()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId"
_, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA")
require.Error(t, err)
apiErr, ok := err.(*DingTalkAPIError)
require.True(t, ok)
require.Equal(t, "60011", apiErr.Code)
}
// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。
func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{
UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me走 test stub 路径
},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
info, err := cli.GetDeptInfo(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, int64(42), info.DeptID)
require.Equal(t, "AI数据", info.Name)
require.Equal(t, int64(1), info.ParentID)
}
// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003部门不存在时返回错误。
func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`))
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "APP_TKN"
cli.appTokenExp = time.Now().Add(time.Hour)
_, err := cli.GetDeptInfo(context.Background(), 999)
require.Error(t, err)
apiErr, ok := err.(*DingTalkAPIError)
require.True(t, ok)
require.Equal(t, "60003", apiErr.Code)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,391 @@
package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestDingTalkOAuthStart_Disabled は sentinel テスト。
// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。
func TestDingTalkOAuthStart_Disabled(t *testing.T) {
t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only")
}
// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。
func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) {
unionID := "union_AbCdEf123"
email := buildDingTalkSyntheticEmail(unionID)
want := "dingtalk-union_abcdef123@dingtalk-connect.invalid"
require.Equal(t, want, email)
// 确保结果都是小写(邮箱大小写不敏感,统一小写)
require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase")
// 确保前缀正确
require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix")
// 确保后缀是合成邮箱域名
require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix")
}
// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。
func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) {
email := buildDingTalkSyntheticEmail(" UID_XYZ ")
require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email)
}
// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct跨组织降级路径
// - subject 等于 unionID与 identityKey.ProviderSubject 一致)
// - corp_user_id 为空字符串(跨组织时拿不到企业 userid
// - email/username 为空字符串
// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{}claims 不应有 nil。
func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) {
staff := &DingTalkStaffInfo{}
claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X")
require.Equal(t, "", claims["email"])
require.Equal(t, "", claims["username"])
// 重构后 subject = unionID与 identityKey.ProviderSubject 保持一致)
require.Equal(t, "UNION_AAA", claims["subject"])
require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空
require.Equal(t, "UNION_AAA", claims["union_id"])
require.Equal(t, "CORP_X", claims["corp_id"])
}
// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。
// D: corp 校验提前后逻辑不变。
func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) {
cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"}
assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp")
assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp")
assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp")
}
// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。
// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段,
// 因此 checkDingTalkCorpAllowed 完全不校验 corpID由 step 3 GetUserIdByUnionId 做真实判定
// (跨企业用户会被钉钉错误码 60011/60121 拒绝mapDingTalkErrorCode 映射回 corp_rejected
func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) {
cfgWithCorpID := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "dingInternal",
}
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed")
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策step 3 兜底")
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId")
cfgNoCorpID := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
InternalCorpID: "",
}
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过")
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过")
}
// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时
// Step 3/4 失败应降级shouldFallback=true, isFatal=false
func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) {
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err)
require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback")
require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal")
}
// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。
func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) {
stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr)
require.True(t, shouldFallback, "policy='': step failure should trigger fallback")
require.False(t, isFatal, "policy='': step failure should NOT be fatal")
}
// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时
// Step 3/4 失败应 hard failisFatal=true
func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) {
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err)
require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error")
require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal")
}
// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。
func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) {
for _, policy := range []string{"none", "internal_only", ""} {
shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil)
require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy)
require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy)
}
}
// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时
// 退到 email local part@ 之前的部分)。
// E: CompleteDingTalkOAuthRegistration username fallback。
func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) {
tests := []struct {
name string
email string
username string
wantUser string
wantValid bool
}{
{
name: "username empty, normal email → local part",
email: "dingtalk-uid123@dingtalk-connect.invalid",
username: "",
wantUser: "dingtalk-uid123",
wantValid: true,
},
{
name: "username already set → keep original",
email: "user@example.com",
username: "张三",
wantUser: "张三",
wantValid: true,
},
{
name: "username empty, no @ in email → use whole email",
email: "noemail",
username: "",
wantUser: "noemail",
wantValid: true,
},
{
name: "both empty → invalid",
email: "",
username: "",
wantUser: "",
wantValid: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
username := tc.username
email := tc.email
// 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑
if username == "" {
if at := strings.Index(email, "@"); at > 0 {
username = email[:at]
} else {
username = email
}
}
isValid := email != "" && username != ""
require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email))
require.Equal(t, tc.wantValid, isValid, "validity check")
})
}
}
// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID
// 而非 staff.UserID与 identityKey.ProviderSubject 保持一致。
// §4.2: buildDingTalkUpstreamClaims subject 字段修正。
func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"}
claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789")
// 重构后 subject = unionID全局唯一与 identityKey.ProviderSubject 一致)
require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor")
// 企业 userid 保留为独立字段,供 audit/debug 使用
require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID")
// union_id 字段与 subject 相同(冗余保留,便于读取)
require.Equal(t, "union456", claims["union_id"])
require.Equal(t, "dingcorp789", claims["corp_id"])
require.Equal(t, "张三", claims["username"])
require.Equal(t, "zhangsan@corp.com", claims["email"])
}
// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时
// corp_user_id 为空字符串(跨组织拿不到企业 useridsubject 仍为 unionID。
func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) {
// 跨组织降级路径staff = &DingTalkStaffInfo{}(所有字段为零值)
staff := &DingTalkStaffInfo{}
claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp")
require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users")
require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback")
require.Equal(t, "", claims["email"])
require.Equal(t, "", claims["username"])
}
// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。
func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}}
claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX")
// 只取首个 dept_id
require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id")
}
// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。
func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"}
claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY")
require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts")
}
// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。
func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) {
staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}}
claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ")
recovered := dingTalkStaffFromClaims(claims)
require.Equal(t, "王五", recovered.Name)
require.Equal(t, "ww@corp.com", recovered.Email)
require.Equal(t, "u3", recovered.UserID)
require.Equal(t, []int64{55}, recovered.DeptIDs)
}
// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门parent_id=1返回部门名。
func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) {
handler := &AuthHandler{}
callCount := 0
responses := map[string]string{
"42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`,
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
var req struct {
DeptID int64 `json:"dept_id"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
w.Header().Set("Content-Type", "application/json")
if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok {
_, _ = w.Write([]byte(resp))
} else {
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
}
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "tok"
cli.appTokenExp = time.Now().Add(time.Hour)
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
require.NoError(t, err)
require.Equal(t, "研发部", path)
require.Equal(t, 2, callCount)
}
// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key
// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证
// syncField 构建逻辑(即 attr key 从 cfg 读取)。
// 间接验证:通过构造定制 cfg确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic
func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) {
handler := &AuthHandler{
userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic
}
cfg := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
SyncCorpEmail: true,
SyncDisplayName: true,
SyncDept: true,
// 自定义 attr key非默认值
SyncCorpEmailAttrKey: "custom_email_key",
SyncDisplayNameAttrKey: "custom_name_key",
SyncDeptAttrKey: "custom_dept_key",
}
staff := &DingTalkStaffInfo{
Name: "张三",
Email: "zhangsan@example.com",
}
// 调用不应 panicuserAttributeService 为 nil 时走 warn 跳过路径)
require.NotPanics(t, func() {
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false)
})
}
// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时
// 使用 fallback 默认值dingtalk_email / dingtalk_name / dingtalk_department
// 此测试主要验证调用路径不 panic实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。
func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) {
handler := &AuthHandler{
userAttributeService: nil,
}
cfg := config.DingTalkConnectConfig{
CorpRestrictionPolicy: "internal_only",
SyncCorpEmail: true,
SyncDisplayName: true,
SyncDept: false,
// 不设置 attr key等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充)
SyncCorpEmailAttrKey: "dingtalk_email",
SyncDisplayNameAttrKey: "dingtalk_name",
SyncDeptAttrKey: "dingtalk_department",
}
staff := &DingTalkStaffInfo{
Name: "李四",
Email: "lisi@corp.com",
}
require.NotPanics(t, func() {
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false)
})
}
// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。
func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) {
handler := &AuthHandler{}
// 模拟42(AI研发) → parent=10(研发部) → parent=1(根)
responses := map[string]string{
"42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`,
"10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`,
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 解析请求 body 拿到 dept_id
var req struct {
DeptID int64 `json:"dept_id"`
}
_ = json.NewDecoder(r.Body).Decode(&req)
key := fmt.Sprintf("%d", req.DeptID)
w.Header().Set("Content-Type", "application/json")
if resp, ok := responses[key]; ok {
_, _ = w.Write([]byte(resp))
} else {
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
}
}))
defer server.Close()
cli := &DingTalkClient{
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
httpClient: server.Client(),
}
cli.appToken = "tok"
cli.appTokenExp = time.Now().Add(time.Hour)
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
require.NoError(t, err)
require.Equal(t, "研发部/AI研发", path)
}

View File

@ -4,6 +4,7 @@ import (
"context"
"log/slog"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@ -18,25 +19,30 @@ import (
// AuthHandler handles authentication-related requests
type AuthHandler struct {
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
cfg *config.Config
authService *service.AuthService
userService *service.UserService
settingSvc *service.SettingService
promoService *service.PromoService
redeemService *service.RedeemService
totpService *service.TotpService
userAttributeService *service.UserAttributeService
dingTalkClientInstance *DingTalkClient
dingTalkClientMu sync.Mutex
}
// NewAuthHandler creates a new AuthHandler
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler {
return &AuthHandler{
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
cfg: cfg,
authService: authService,
userService: userService,
settingSvc: settingService,
promoService: promoService,
redeemService: redeemService,
totpService: totpService,
userAttributeService: userAttributeService,
}
}

View File

@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
},
})
if err != nil {
slog.Error("pending auth session create failed",
"intent", strings.TrimSpace(payload.Intent),
"provider_type", strings.TrimSpace(payload.Identity.ProviderType),
"provider_key", strings.TrimSpace(payload.Identity.ProviderKey),
"provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)),
"resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)),
"has_target_user", payload.TargetUserID != nil,
"error", err.Error())
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
}
@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
}
// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。
// 钉钉跨组织/staff 邮箱缺失时进入此状态前端跳到补邮箱页exchange 不应走 adoption apply。
func pendingSessionRequiresEmailCompletion(payload map[string]any) bool {
if v, ok := payload["requires_email_completion"].(bool); ok && v {
return true
}
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion")
}
// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。
// 钉钉 signupBlocked=true注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单,
// exchange 不应消费 session否则后续 /pending/bind-login 找不到 session。
func pendingSessionRequiresBindLogin(payload map[string]any) bool {
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required")
}
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
if session == nil {
return false
@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]
delete(normalized, key)
}
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
// 把多种 choice 别名归一为 oauthPendingChoiceStepbind_login_required 是独立终态
// (前端渲染 needsBindLogin 而非 needsChooser故不能并入归一化列表。
switch step {
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
case "choice", "choose_account_action", "choose_account", "choose", "email_required":
normalized["step"] = oauthPendingChoiceStep
}
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
// bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username用户已有自己的名字
h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID)
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
if err != nil {
response.InternalError(c, "Failed to generate token pair")
@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
}
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
// createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值
h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID)
clearCookies()
writeOAuthTokenPairResponse(c, tokenPair)
}
@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
response.Success(c, payload)
return
}
if pendingSessionRequiresEmailCompletion(payload) {
response.Success(c, payload)
return
}
if pendingSessionRequiresBindLogin(payload) {
response.Success(c, payload)
return
}
if !adoptionDecision.hasDecision() {
adoptionRequired, _ := payload["adoption_required"].(bool)
if adoptionRequired {

View File

@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string)
if email == "" ||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
return nil, nil
}
@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
return
}
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat")
if err != nil {
response.ErrorFrom(c, err)
return

View File

@ -0,0 +1,67 @@
package dto
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) {
src := &service.Account{
ID: 42,
Name: "demo",
Platform: "anthropic",
Type: "oauth",
Credentials: map[string]any{
"access_token": "at-secret",
"refresh_token": "rt-secret",
"id_token": "id-secret",
"api_key": "sk-secret",
"base_url": "https://api.example.com",
"model_mapping": map[string]any{"foo": "bar"},
},
}
got := AccountFromServiceShallow(src)
require.NotNil(t, got)
// 敏感键不在 Credentials 里
require.NotContains(t, got.Credentials, "access_token")
require.NotContains(t, got.Credentials, "refresh_token")
require.NotContains(t, got.Credentials, "id_token")
require.NotContains(t, got.Credentials, "api_key")
// 非敏感键保留
require.Equal(t, "https://api.example.com", got.Credentials["base_url"])
require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"])
// 状态 map 标记敏感键存在
require.True(t, got.CredentialsStatus["has_access_token"])
require.True(t, got.CredentialsStatus["has_refresh_token"])
require.True(t, got.CredentialsStatus["has_id_token"])
require.True(t, got.CredentialsStatus["has_api_key"])
// JSON 序列化校验:响应体里不会出现敏感子串
raw, err := json.Marshal(got)
require.NoError(t, err)
require.NotContains(t, string(raw), "rt-secret")
require.NotContains(t, string(raw), "at-secret")
require.NotContains(t, string(raw), "sk-secret")
require.NotContains(t, string(raw), "id-secret")
// 状态标识应序列化进 JSON
require.Contains(t, string(raw), "credentials_status")
require.Contains(t, string(raw), "has_refresh_token")
// 原始 service.Account 不应被改动
require.Equal(t, "rt-secret", src.Credentials["refresh_token"])
}
func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) {
src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"}
got := AccountFromServiceShallow(src)
require.NotNil(t, got)
require.Nil(t, got.Credentials)
require.Nil(t, got.CredentialsStatus)
}

View File

@ -0,0 +1,44 @@
// Package dto provides data transfer objects for HTTP handlers.
package dto
import "github.com/Wei-Shaw/sub2api/internal/service"
// RedactCredentials 复制一份 in剥离 service.SensitiveCredentialKeys 列出的所有敏感子键,
// 并产出一个 has_<key> 状态 map 表示哪些敏感键存在且非零值。
//
// 输入 nil 时返回 nil, nil避免响应里出现空对象
// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。
func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) {
if in == nil {
return nil, nil
}
out = make(map[string]any, len(in))
for k, v := range in {
if service.IsSensitiveCredentialKey(k) {
if isCredentialValuePresent(v) {
if status == nil {
status = make(map[string]bool, 4)
}
status["has_"+k] = true
}
continue
}
out[k] = v
}
return out, status
}
// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置;
// 其余非零类型(数字、对象、字符串等)视为已配置。
func isCredentialValuePresent(v any) bool {
switch x := v.(type) {
case nil:
return false
case string:
return x != ""
case bool:
return x
default:
return true
}
}

View File

@ -0,0 +1,97 @@
package dto
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestRedactCredentials_NilInput(t *testing.T) {
out, status := RedactCredentials(nil)
require.Nil(t, out)
require.Nil(t, status)
}
func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) {
in := map[string]any{
"refresh_token": "rt-secret",
"access_token": "at-secret",
"api_key": "sk-secret",
"aws_secret_access_key": "aws-secret",
"service_account_json": map[string]any{"private_key": "..."},
"private_key": "raw-key",
// 非敏感
"base_url": "https://api.example.com",
"model_mapping": map[string]any{"foo": "bar"},
"project_id": "proj-1",
"expires_at": int64(123456),
}
out, status := RedactCredentials(in)
require.NotContains(t, out, "refresh_token")
require.NotContains(t, out, "access_token")
require.NotContains(t, out, "api_key")
require.NotContains(t, out, "aws_secret_access_key")
require.NotContains(t, out, "service_account_json")
require.NotContains(t, out, "private_key")
require.Equal(t, "https://api.example.com", out["base_url"])
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
require.Equal(t, "proj-1", out["project_id"])
require.Equal(t, int64(123456), out["expires_at"])
require.True(t, status["has_refresh_token"])
require.True(t, status["has_access_token"])
require.True(t, status["has_api_key"])
require.True(t, status["has_aws_secret_access_key"])
require.True(t, status["has_service_account_json"])
require.True(t, status["has_private_key"])
// 状态 map 不应携带非敏感键的 has_*
require.NotContains(t, status, "has_base_url")
require.NotContains(t, status, "has_project_id")
}
func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) {
in := map[string]any{
"refresh_token": "",
"access_token": nil,
"api_key": false,
"id_token": "actual-id",
}
out, status := RedactCredentials(in)
require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output")
require.False(t, status["has_refresh_token"])
require.False(t, status["has_access_token"])
require.False(t, status["has_api_key"])
require.True(t, status["has_id_token"])
}
func TestRedactCredentials_DoesNotMutateInput(t *testing.T) {
in := map[string]any{
"refresh_token": "secret",
"base_url": "x",
}
_, _ = RedactCredentials(in)
require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改")
require.Equal(t, "x", in["base_url"])
}
func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) {
keys := []string{
"access_token", "refresh_token", "id_token",
"api_key", "session_key", "cookie",
"aws_secret_access_key", "aws_session_token",
"service_account_json", "service_account", "private_key",
}
in := make(map[string]any, len(keys))
for _, k := range keys {
in[k] = "filled"
}
out, status := RedactCredentials(in)
require.Empty(t, out)
for _, k := range keys {
require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k)
}
}

View File

@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
redactedCreds, credsStatus := RedactCredentials(a.Credentials)
out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
Platform: a.Platform,
Type: a.Type,
Credentials: a.Credentials,
Credentials: redactedCreds,
CredentialsStatus: credsStatus,
Extra: a.Extra,
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
UsedBy: rc.UsedBy,
UsedAt: rc.UsedAt,
CreatedAt: rc.CreatedAt,
ExpiresAt: rc.ExpiresAt,
GroupID: rc.GroupID,
ValidityDays: rc.ValidityDays,
User: UserFromServiceShallow(rc.User),
Group: GroupFromServiceShallow(rc.Group),
}
if rc.IsExpired() {
out.Status = service.StatusExpired
}
// For admin_balance/admin_concurrency types, include notes so users can see
// why they were charged or credited by admin
@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
FirstTokenMs: l.FirstTokenMs,
ImageCount: l.ImageCount,
ImageSize: l.ImageSize,
ImageInputSize: l.ImageInputSize,
ImageOutputSize: l.ImageOutputSize,
ImageSizeSource: l.ImageSizeSource,
ImageSizeBreakdown: l.ImageSizeBreakdown,
MediaType: l.MediaType,
UserAgent: l.UserAgent,
CacheTTLOverridden: l.CacheTTLOverridden,

View File

@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *
require.Equal(t, "claude-3", adminDTO.Model)
}
func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) {
t.Parallel()
imageSize := "4K"
inputSize := "1024x1024"
outputSize := "3840x2160"
source := "output"
log := &service.UsageLog{
RequestID: "req_image_metadata",
Model: "gpt-image-2",
ImageCount: 2,
ImageSize: &imageSize,
ImageInputSize: &inputSize,
ImageOutputSize: &outputSize,
ImageSizeSource: &source,
ImageSizeBreakdown: map[string]int{"4K": 2},
}
userDTO := UsageLogFromService(log)
adminDTO := UsageLogFromServiceAdmin(log)
for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} {
require.Equal(t, 2, got.ImageCount)
require.NotNil(t, got.ImageSize)
require.Equal(t, imageSize, *got.ImageSize)
require.NotNil(t, got.ImageInputSize)
require.Equal(t, inputSize, *got.ImageInputSize)
require.NotNil(t, got.ImageOutputSize)
require.Equal(t, outputSize, *got.ImageOutputSize)
require.NotNil(t, got.ImageSizeSource)
require.Equal(t, source, *got.ImageSizeSource)
require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown)
}
}
func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) {
t.Parallel()
log := &service.UsageLog{
RequestID: "req_legacy_image_missing_size",
Model: "gpt-image-2",
ImageCount: 1,
ImageSize: nil,
}
dto := UsageLogFromService(log)
require.Equal(t, 1, dto.ImageCount)
require.Nil(t, dto.ImageSize)
require.Nil(t, dto.ImageInputSize)
require.Nil(t, dto.ImageOutputSize)
require.Nil(t, dto.ImageSizeSource)
require.Nil(t, dto.ImageSizeBreakdown)
body, err := json.Marshal(dto)
require.NoError(t, err)
require.Contains(t, string(body), `"image_size":null`)
require.NotContains(t, string(body), `"image_size":"2K"`)
}
func f64Ptr(value float64) *float64 {
return &value
}

View File

@ -56,6 +56,23 @@ type SystemSettings struct {
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"`
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
WeChatConnectAppID string `json:"wechat_connect_app_id"`
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
@ -201,6 +218,9 @@ type SystemSettings struct {
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"`
// Balance low notification
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
@ -260,6 +280,7 @@ type PublicSettings struct {
TablePageSizeOptions []int `json:"table_page_size_options"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`

View File

@ -149,25 +149,28 @@ type AdminGroup struct {
}
type Account struct {
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID int64 `json:"id"`
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
// Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥
// 的存在性通过 CredentialsStatushas_<key>)暴露,原始值不返回前端。
Credentials map[string]any `json:"credentials"`
CredentialsStatus map[string]bool `json:"credentials_status,omitempty"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
LoadFactor *int `json:"load_factor,omitempty"`
Priority int `json:"priority"`
RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Schedulable bool `json:"schedulable"`
@ -335,6 +338,7 @@ type RedeemCode struct {
UsedBy *int64 `json:"used_by"`
UsedAt *time.Time `json:"used_at"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
GroupID *int64 `json:"group_id"`
ValidityDays int `json:"validity_days"`
@ -400,9 +404,13 @@ type UsageLog struct {
FirstTokenMs *int `json:"first_token_ms"`
// 图片生成字段
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
MediaType *string `json:"media_type"`
ImageCount int `json:"image_count"`
ImageSize *string `json:"image_size"`
ImageInputSize *string `json:"image_input_size"`
ImageOutputSize *string `json:"image_output_size"`
ImageSizeSource *string `json:"image_size_source"`
ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
MediaType *string `json:"media_type"`
// User-Agent
UserAgent *string `json:"user_agent"`

View File

@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@ -157,7 +158,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
@ -209,7 +210,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 验证 model 必填
@ -351,6 +352,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", apiKey.GroupID),
@ -400,6 +402,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@ -621,6 +624,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
reqLog.Warn("gateway.select_account_no_available",
zap.String("model", reqModel),
zap.Int64p("group_id", currentAPIKey.GroupID),
@ -681,6 +685,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
zap.Int64("account_id", account.ID),
zap.String("model", reqModel),
@ -1041,8 +1046,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
platform = forcedPlatform
}
// Get available models from account configurations (without platform filter)
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
// Get available models from account configurations for the selected group platform.
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
if len(availableModels) > 0 {
// Build model list from whitelist
@ -1063,7 +1068,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
}
// Fallback to default models
if platform == "openai" {
if platform == service.PlatformOpenAI {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": openai.DefaultModels,
@ -1071,6 +1076,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
return
}
if platform == service.PlatformGemini {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": geminicli.DefaultModels,
})
return
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": claude.DefaultModels,
@ -1407,6 +1420,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
return
}
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {
@ -1611,7 +1629,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
if err != nil {
@ -1630,7 +1648,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
// 获取订阅信息可能为nil
@ -1659,6 +1677,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}

View File

@ -60,7 +60,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
// Validate JSON
if !gjson.ValidBytes(body) {
@ -78,7 +78,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
APIKeyID: apiKey.ID,
}
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
groupPlatform := ""
if apiKey.Group != nil {
groupPlatform = apiKey.Group.Platform
}
selectionSessionHash := sessionHash
if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
selectionSessionHash = "gemini:" + selectionSessionHash
}
// 3. Account selection + failover loop
fs := NewFailoverState(h.maxAccountSwitches, false)
if groupPlatform == service.PlatformGemini {
fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
}
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
}
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
fs.FailedAccountIDs[account.ID] = struct{}{}
continue
}
// 5. Forward request
writerSizeBeforeForward := c.Writer.Size()
forwardBody := body
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
var result *service.ForwardResult
if account.Platform == service.PlatformGemini {
if h.geminiCompatService == nil {
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
if accountReleaseFunc != nil {
accountReleaseFunc()
}
return
}
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
} else {
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
}
if accountReleaseFunc != nil {
accountReleaseFunc()
@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
return
}
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -60,7 +60,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
// Validate JSON
if !gjson.ValidBytes(body) {
@ -78,7 +78,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
reqStream := gjson.GetBytes(body, "stream").Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
// 解析渠道级模型映射
@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
return
}
@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
return
}
@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr
if lastErr != nil && lastErr.StatusCode > 0 {
statusCode = lastErr.StatusCode
}
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
return
}
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
}

View File

@ -0,0 +1,136 @@
package handler
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type gatewayModelsAccountRepoStub struct {
service.AccountRepository
byGroup map[int64][]service.Account
}
type gatewayModelsResponseForTest struct {
Object string `json:"object"`
Data []gatewayModelItemForTest `json:"data"`
}
type gatewayModelItemForTest struct {
ID string `json:"id"`
}
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
accounts, ok := s.byGroup[groupID]
if !ok {
return nil, nil
}
out := make([]service.Account, len(accounts))
copy(out, accounts)
return out, nil
}
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
return &GatewayHandler{
gatewayService: service.NewGatewayService(
repo,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
),
}
}
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(20)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{ID: 1, Platform: service.PlatformGemini},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, "list", got.Object)
require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash")
require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6")
}
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
gin.SetMode(gin.TestMode)
groupID := int64(21)
h := newGatewayModelsHandlerForTest(
&gatewayModelsAccountRepoStub{
byGroup: map[int64][]service.Account{
groupID: {
{
ID: 1,
Platform: service.PlatformAnthropic,
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-6": "claude-sonnet-4-6",
},
},
},
{
ID: 2,
Platform: service.PlatformGemini,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-2.5-flash": "gemini-2.5-flash",
},
},
},
},
},
},
)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
})
h.Models(c)
require.Equal(t, http.StatusOK, rec.Code)
var got gatewayModelsResponseForTest
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
}
func modelIDsForTest(models []gatewayModelItemForTest) []string {
ids := make([]string, 0, len(models))
for _, model := range models {
ids = append(ids, model.ID)
}
return ids
}

View File

@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModelsList())
return
}
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
return
}
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -182,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
return
}
setOpsRequestContext(c, modelName, stream, body)
setOpsRequestContext(c, modelName, stream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
if err != nil {
if len(fs.FailedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
return
}
@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
accountReleaseFunc := selection.ReleaseFunc
if !selection.Acquired {
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
return
}

View File

@ -78,7 +78,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
} else {
@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
forwardDurationMs := time.Since(forwardStart).Milliseconds()
@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// Pool mode: retry on the same account
if failoverErr.RetryableOnSameAccount {
@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
// has been probed to not support the Responses API, the request is
// is forced or probed to not support the Responses API, the request is
// forwarded directly to /v1/chat/completions — not through the default
// CC→Responses conversion path.
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {

View File

@ -130,7 +130,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
sessionHashBody := body
if service.IsOpenAIResponsesCompactPathForTest(c) {
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
@ -189,7 +189,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
return
@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if channelMapping.Mapped {
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
}
writerSizeBeforeForward := c.Writer.Size()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} else {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
if c.Writer.Size() != writerSizeBeforeForward {
h.handleFailoverExhausted(c, failoverErr, true)
return
}
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
// 池模式:同账号重试
if failoverErr.RetryableOnSameAccount {
@ -604,7 +611,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
setOpsRequestContext(c, reqModel, reqStream, body)
setOpsRequestContext(c, reqModel, reqStream)
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
)
if len(failedAccountIDs) == 0 {
if err != nil {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
}
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
reqLog *zap.Logger,
) (func(), bool) {
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
}
if selection.WaitPlan == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
return nil, false
}
@ -1163,7 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.Bool("has_previous_response_id", previousResponseID != ""),
zap.String("previous_response_id_kind", previousResponseIDKind),
)
setOpsRequestContext(c, reqModel, true, firstMessage)
setOpsRequestContext(c, reqModel, true)
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
statusCode := failoverErr.StatusCode
responseBody := failoverErr.ResponseBody
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
return
}
// 先检查透传规则
if h.errorPassthroughService != nil && len(responseBody) > 0 {

View File

@ -63,9 +63,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
}
if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
setOpsRequestContext(c, "", false, nil)
setOpsRequestContext(c, "", false)
} else {
setOpsRequestContext(c, "", false, body)
setOpsRequestContext(c, "", false)
}
parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
@ -98,9 +98,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
}
if parsed.Multipart {
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
setOpsRequestContext(c, parsed.Model, parsed.Stream)
} else {
setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
setOpsRequestContext(c, parsed.Model, parsed.Stream)
}
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
zap.Int("excluded_account_count", len(failedAccountIDs)),
)
if len(failedAccountIDs) == 0 {
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}
@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
return
}
if selection == nil || selection.Account == nil {
markOpsRoutingCapacityLimited(c)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
return
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"runtime"
"runtime/debug"
@ -22,10 +23,10 @@ import (
)
const (
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
opsRequestBodyKey = "ops_request_body"
opsAccountIDKey = "ops_account_id"
opsModelKey = "ops_model"
opsStreamKey = "ops_stream"
opsAccountIDKey = "ops_account_id"
opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited"
opsUpstreamModelKey = "ops_upstream_model"
opsRequestTypeKey = "ops_request_type"
@ -45,6 +46,8 @@ const (
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
opsCodeUserInactive = "USER_INACTIVE"
opsCodeInvalidAPIKey = "INVALID_API_KEY"
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
)
const (
@ -332,16 +335,13 @@ func opsErrorLogConfig() (workerCount int, queueSize int) {
return workerCount, queueSize
}
func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) {
func setOpsRequestContext(c *gin.Context, model string, stream bool) {
if c == nil {
return
}
model = strings.TrimSpace(model)
c.Set(opsModelKey, model)
c.Set(opsStreamKey, stream)
if len(requestBody) > 0 {
c.Set(opsRequestBodyKey, requestBody)
}
if c.Request != nil && model != "" {
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
c.Request = c.Request.WithContext(ctx)
@ -360,22 +360,6 @@ func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int
c.Set(opsRequestTypeKey, requestType)
}
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
v, ok := c.Get(opsRequestBodyKey)
if !ok {
return
}
raw, ok := v.([]byte)
if !ok || len(raw) == 0 {
return
}
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
opsErrorLogSanitized.Add(1)
}
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
if c == nil || accountID <= 0 {
return
@ -393,6 +377,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string)
}
}
func markOpsRoutingCapacityLimited(c *gin.Context) {
if c == nil {
return
}
c.Set(opsRoutingCapacityLimitedKey, true)
}
func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) {
if !isOpsNoAvailableAccountError(err) {
return
}
markOpsRoutingCapacityLimited(c)
}
func isOpsRoutingCapacityLimited(c *gin.Context) bool {
if c == nil {
return false
}
v, ok := c.Get(opsRoutingCapacityLimitedKey)
if !ok {
return false
}
marked, _ := v.(bool)
return marked
}
func isOpsNoAvailableAccountError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) {
return true
}
return isOpsNoAvailableAccountMessage(err.Error())
}
type opsCaptureWriter struct {
gin.ResponseWriter
limit int
@ -671,7 +691,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorPhase: "upstream",
ErrorType: "upstream_error",
// Severity/retryability should reflect the upstream failure, not the final client status (200).
// Severity should reflect the upstream failure, not the final client status (200).
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
StatusCode: status,
IsBusinessLimited: false,
@ -688,9 +708,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UpstreamErrorDetail: upstreamErrorDetail,
UpstreamErrors: events,
IsRetryable: classifyOpsIsRetryable("upstream_error", effectiveUpstreamStatus),
RetryCount: 0,
CreatedAt: time.Now(),
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
@ -714,10 +732,6 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
// Skip logging if a passthrough rule with skip_monitoring=true matched.
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
if skip, _ := v.(bool); skip {
@ -775,11 +789,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status)
entry := &service.OpsInsertErrorLogInput{
RequestID: requestID,
@ -834,9 +844,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
IsRetryable: classifyOpsIsRetryable(normalizedType, status),
RetryCount: 0,
CreatedAt: time.Now(),
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
@ -914,20 +922,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
entry.ClientIP = &clientIP
}
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
// Do NOT store Authorization/Cookie/etc.
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
attachOpsRequestBodyToEntry(c, entry)
enqueueOpsErrorLog(ops, entry)
}
}
var opsRetryRequestHeaderAllowlist = []string{
"anthropic-beta",
"anthropic-version",
}
// isCountTokensRequest checks if the request is a count_tokens request
func isCountTokensRequest(c *gin.Context) bool {
if c == nil || c.Request == nil || c.Request.URL == nil {
@ -936,32 +934,6 @@ func isCountTokensRequest(c *gin.Context) bool {
return strings.Contains(c.Request.URL.Path, "/count_tokens")
}
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
if c == nil || c.Request == nil {
return nil
}
headers := make(map[string]string, 4)
for _, key := range opsRetryRequestHeaderAllowlist {
v := strings.TrimSpace(c.GetHeader(key))
if v == "" {
continue
}
// Keep headers small even if a client sends something unexpected.
headers[key] = truncateString(v, 512)
}
if len(headers) == 0 {
return nil
}
raw, err := json.Marshal(headers)
if err != nil {
return nil
}
s := string(raw)
return &s
}
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
@ -1116,6 +1088,9 @@ func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
// Standardized phases: request|auth|routing|upstream|network|internal
// Map billing/concurrency/response => request; scheduling => routing.
if isOpsClientAuthError(code, msg) {
return "auth"
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
return "request"
@ -1136,7 +1111,7 @@ func classifyOpsPhase(errType, message, code string) string {
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, opsErrNoAvailableAccounts) {
if isOpsNoAvailableAccountMessage(msg) {
return "routing"
}
return "internal"
@ -1162,25 +1137,31 @@ func classifyOpsSeverity(errType string, status int) string {
return "P3"
}
func classifyOpsIsRetryable(errType string, statusCode int) bool {
switch errType {
case "authentication_error", "invalid_request_error":
return false
case "timeout_error":
return true
case "rate_limit_error":
// May be transient (upstream or queue); retry can help.
return true
case "billing_error", "subscription_error":
return false
case "upstream_error", "overloaded_error":
return statusCode >= 500 || statusCode == 429 || statusCode == 529
default:
return statusCode >= 500
func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) {
phase = classifyOpsPhase(errType, message, code)
routingCapacityLimited := isOpsRoutingCapacityLimited(c)
clientBusinessLimited := service.HasOpsClientBusinessLimited(c)
upstreamError := hasOpsUpstreamErrorContext(c)
if upstreamError && !routingCapacityLimited {
phase = "upstream"
}
if clientBusinessLimited && !upstreamError && !routingCapacityLimited {
phase = "auth"
}
if routingCapacityLimited {
phase = "routing"
}
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
errorOwner = classifyOpsErrorOwner(phase, message)
errorSource = classifyOpsErrorSource(phase, message)
return phase, isBusinessLimited, errorOwner, errorSource
}
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool {
if len(localClientAuthError) > 0 && localClientAuthError[0] {
return true
}
switch strings.TrimSpace(code) {
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
return true
@ -1197,6 +1178,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
return false
}
func isOpsClientAuthError(code string, msg string) bool {
switch strings.TrimSpace(code) {
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
return true
}
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
}
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
if c == nil {
return false
}
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
switch code := v.(type) {
case int:
if code > 0 {
return true
}
case int64:
if code > 0 {
return true
}
}
}
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
return true
}
}
return false
}
func isOpsNoAvailableAccountMessage(message string) bool {
msg := strings.ToLower(message)
return strings.Contains(msg, opsErrNoAvailableAccounts) ||
strings.Contains(msg, "no available account") ||
strings.Contains(msg, "no available gemini accounts") ||
strings.Contains(msg, "no available openai accounts") ||
strings.Contains(msg, "no available compatible accounts")
}
func classifyOpsErrorOwner(phase string, message string) string {
// Standardized owners: client|provider|platform
switch phase {

View File

@ -44,49 +44,6 @@ func resetOpsErrorLoggerStateForTest(t *testing.T) {
opsErrorLogDrained.Store(false)
}
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
setOpsRequestContext(c, "claude-3", false, raw)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(c, entry)
require.NotNil(t, entry.RequestBodyBytes)
require.Equal(t, len(raw), *entry.RequestBodyBytes)
require.NotNil(t, entry.RequestBodyJSON)
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
raw := []byte("not-json")
setOpsRequestContext(c, "claude-3", false, raw)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.NotNil(t, entry.RequestBodyBytes)
require.Equal(t, len(raw), *entry.RequestBodyBytes)
require.False(t, entry.RequestBodyTruncated)
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
}
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
@ -108,39 +65,6 @@ func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
require.Equal(t, int64(1), OpsErrorLogQueueLength())
}
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
gin.SetMode(gin.TestMode)
entry := &service.OpsInsertErrorLogInput{}
attachOpsRequestBodyToEntry(nil, entry)
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
// 无请求体 key
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
require.False(t, entry.RequestBodyTruncated)
// 错误类型
c.Set(opsRequestBodyKey, "not-bytes")
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
// 空 bytes
c.Set(opsRequestBodyKey, []byte{})
attachOpsRequestBodyToEntry(c, entry)
require.Nil(t, entry.RequestBodyJSON)
require.Nil(t, entry.RequestBodyBytes)
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
}
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
resetOpsErrorLoggerStateForTest(t)
@ -275,6 +199,218 @@ func TestNormalizeOpsErrorType(t *testing.T) {
}
}
func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) {
const message = "No available accounts"
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
errType := normalizeOpsErrorType("api_error", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
require.Equal(t, "api_error", errType)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"Service temporarily unavailable",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
tests := []struct {
name string
errType string
message string
code string
status int
}{
{
name: "standard invalid API key",
errType: "api_error",
message: "Invalid API key",
code: "INVALID_API_KEY",
status: http.StatusUnauthorized,
},
{
name: "standard missing API key",
errType: "api_error",
message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header",
code: "API_KEY_REQUIRED",
status: http.StatusUnauthorized,
},
{
name: "google invalid API key",
errType: "api_error",
message: "Invalid API key",
code: "401",
status: http.StatusUnauthorized,
},
{
name: "google missing API key",
errType: "api_error",
message: "API key is required",
code: "401",
status: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
errType := normalizeOpsErrorType(tt.errType, tt.code)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
require.Equal(t, "api_error", errType)
require.Equal(t, "auth", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
})
}
}
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden)
require.Equal(t, "api_error", errType)
require.Equal(t, "auth", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "client", errorOwner)
require.Equal(t, "client_request", errorSource)
}
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError)
require.Equal(t, "api_error", errType)
require.Equal(t, "internal", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) {
tests := []string{
"No available accounts: no available accounts supporting model: made-up-model",
"No available accounts: no available OpenAI accounts supporting model: made-up-model",
"No available Gemini accounts: no available Gemini accounts supporting model: made-up-model",
"No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)",
}
for _, message := range tests {
t.Run(message, func(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
markOpsRoutingCapacityLimited(c)
errType := normalizeOpsErrorType("api_error", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
require.Equal(t, "api_error", errType)
require.Equal(t, "routing", phase)
require.True(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
})
}
}
func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"No available accounts",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "routing", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "platform", errorOwner)
require.Equal(t, "gateway", errorSource)
}
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"Invalid API key",
"401",
http.StatusUnauthorized,
)
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
}
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "")
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
c,
"api_error",
"No available accounts",
"",
http.StatusServiceUnavailable,
)
require.Equal(t, "upstream", phase)
require.False(t, isBusinessLimited)
require.Equal(t, "provider", errorOwner)
require.Equal(t, "upstream_http", errorSource)
}
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()

View File

@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected direct image path to be accepted")
}
want := filepath.Join(base, "logo.png")
want := mustEvalSymlinks(t, filepath.Join(base, "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) {
if !ok {
t.Fatal("expected nested image path to be accepted")
}
want = filepath.Join(base, "images", "logo.png")
want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png"))
if got != want {
t.Fatalf("path = %q, want %q", got, want)
}
@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
t.Fatalf("expected symlink escape to be rejected, got %q", got)
}
}
func mustEvalSymlinks(t *testing.T, path string) string {
t.Helper()
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
t.Fatalf("eval symlinks for %q: %v", path, err)
}
return realPath
}

View File

@ -141,6 +141,7 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
HelpText: cfg.HelpText,
HelpImageURL: cfg.HelpImageURL,
StripePublishableKey: cfg.StripePublishableKey,
AlipayForceQRCode: cfg.AlipayForceQRCode,
})
}
@ -155,6 +156,7 @@ type checkoutInfoResponse struct {
HelpText string `json:"help_text"`
HelpImageURL string `json:"help_image_url"`
StripePublishableKey string `json:"stripe_publishable_key"`
AlipayForceQRCode bool `json:"alipay_force_qrcode"`
}
type checkoutPlan struct {

View File

@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
TablePageSizeOptions: settings.TablePageSizeOptions,
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,

View File

@ -67,6 +67,7 @@ type userProfileResponse struct {
LinuxDoBound bool `json:"linuxdo_bound"`
OIDCBound bool `json:"oidc_bound"`
WeChatBound bool `json:"wechat_bound"`
DingTalkBound bool `json:"dingtalk_bound"`
}
type userProfileSourceContext struct {
@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
LinuxDoBound: identities.LinuxDo.Bound,
OIDCBound: identities.OIDC.Bound,
WeChatBound: identities.WeChat.Bound,
DingTalkBound: identities.DingTalk.Bound,
}
}
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
return map[string]service.UserIdentitySummary{
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
"email": identities.Email,
"linuxdo": identities.LinuxDo,
"oidc": identities.OIDC,
"wechat": identities.WeChat,
"dingtalk": identities.DingTalk,
}
}
@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
out := make([]service.UserIdentitySummary, 0, 3)
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} {
if summary.Bound {
out = append(out, summary)
}

View File

@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
// CreatePayment creates an Alipay payment using the following routing:
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
// - Desktop fallback: if precreate is unavailable for the merchant, fall back
// to alipay.trade.page.pay and expose both pay_url and qr_code so the
// frontend can render a QR while still allowing direct page open.
// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to
// get a scannable QR payload. If precreate is unavailable for the merchant,
// fall back to alipay.trade.page.pay and expose pay_url only — the frontend
// opens the Alipay checkout in a new tab.
// - Desktop, paymentMode == "redirect": skip precreate and go straight to
// alipay.trade.page.pay so the frontend always opens the Alipay checkout
// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT.
//
// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable
// payment QR. Never expose it via the QRCode field.
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
client, err := a.getClient()
if err != nil {
@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
}
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
// Explicit redirect mode: merchant opted into "always open the Alipay
// checkout page in a new tab" via the provider instance's payment_mode.
// Skip precreate to avoid a wasted API call.
if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") {
return a.createPagePayTrade(client, req, notifyURL, returnURL)
}
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
if precreateErr == nil {
return resp, nil
@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
if err != nil {
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
}
// Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL
// that must be opened in a browser, not a scannable payment QR. Setting it
// as QRCode would let the frontend render an unscannable image.
return &payment.CreatePaymentResponse{
TradeNo: req.OrderID,
PayURL: payURL.String(),
QRCode: payURL.String(),
}, nil
}

View File

@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
if resp.PayURL == "" {
t.Fatal("expected pay_url for desktop page pay")
}
if resp.QRCode != resp.PayURL {
t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
// page.pay returns a checkout page URL, not a scannable QR payload —
// it must never be exposed via QRCode (the frontend would render an
// unscannable image from it).
if resp.QRCode != "" {
t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode)
}
}
// When the provider instance is configured with paymentMode == "redirect",
// the desktop flow must skip precreate and go straight to page.pay.
func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) {
origPreCreate := alipayTradePreCreate
origPagePay := alipayTradePagePay
t.Cleanup(func() {
alipayTradePreCreate = origPreCreate
alipayTradePagePay = origPagePay
})
preCreateCalls := 0
pagePayCalls := 0
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
preCreateCalls++
return &alipay.TradePreCreateRsp{
Error: alipay.Error{Code: alipay.CodeSuccess},
QRCode: "https://qr.alipay.example.com/precreate-token",
}, nil
}
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
pagePayCalls++
if param.ProductCode != alipayProductCodePagePay {
t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay)
}
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
}
provider := &Alipay{
config: map[string]string{"paymentMode": "redirect"},
}
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
OrderID: "sub2_103",
Amount: "12.00",
Subject: "Balance recharge",
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if preCreateCalls != 0 {
t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls)
}
if pagePayCalls != 1 {
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
}
if resp.PayURL == "" {
t.Fatal("expected pay_url for redirect mode")
}
if resp.QRCode != "" {
t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode)
}
}

View File

@ -254,6 +254,8 @@ const (
proxyTLSHandshakeTimeout = 5 * time.Second
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body
clientTimeout = 10 * time.Second
// fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
fetchAvailableModelsBodyLimit int64 = 8 << 20
)
func NewClient(proxyURL string) (*Client, error) {
@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
// 支持 URL fallbacksandbox → daily → prod
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
if c == nil || c.httpClient == nil {
return nil, nil, errors.New("antigravity client is not configured")
}
reqBody := FetchAvailableModelsRequest{Project: projectID}
bodyBytes, err := json.Marshal(reqBody)
if err != nil {
@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
// 固定顺序prod -> daily
availableURLs := BaseURLs
fetchClient := c.fetchAvailableModelsHTTPClient()
var lastErr error
for urlIdx, baseURL := range availableURLs {
apiURL := baseURL + "/v1internal:fetchAvailableModels"
@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
resp, err := c.httpClient.Do(req)
resp, err := fetchClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
respBodyBytes, err := io.ReadAll(resp.Body)
respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
if err != nil {
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
}
if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
}
// 检查是否需要 URL 降级
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
return nil, nil, lastErr
}
func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
fetchClient := *c.httpClient
fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
return &fetchClient
}
func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
if req == nil || req.URL == nil {
return errors.New("redirect url is nil")
}
if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
}
return nil
}
func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
host = strings.ToLower(strings.TrimSpace(host))
if host == "" {
return false
}
for _, baseURL := range BaseURLs {
parsed, err := url.Parse(baseURL)
if err != nil {
continue
}
if strings.EqualFold(host, parsed.Hostname()) {
return true
}
}
return false
}
// ── Privacy API ──────────────────────────────────────────────────────
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)

View File

@ -744,6 +744,10 @@ func TestStreamingReasoning(t *testing.T) {
assert.Equal(t, "content_block_start", events[0].Type)
assert.Equal(t, "thinking", events[0].ContentBlock.Type)
sse, err := ResponsesAnthropicEventToSSE(events[0])
require.NoError(t, err)
assert.Contains(t, sse, `"content_block":{"thinking":"","type":"thinking"}`)
// reasoning text delta
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
Type: "response.reasoning_summary_text.delta",
@ -1520,3 +1524,49 @@ func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
assert.JSONEq(t, `"object"`, string(params["type"]))
assert.JSONEq(t, `{}`, string(params["properties"]))
}
// ---------------------------------------------------------------------------
// isReasoningModel / temperature-stripping tests
// ---------------------------------------------------------------------------
func TestAnthropicToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
temp := 0.7
req := &AnthropicRequest{
Model: "gpt-5.2",
MaxTokens: 1024,
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
Temperature: &temp,
TopP: &temp,
}
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
// Verify the fields are absent from the serialised JSON.
b, err := json.Marshal(resp)
require.NoError(t, err)
assert.NotContains(t, string(b), `"temperature"`)
assert.NotContains(t, string(b), `"top_p"`)
}
func TestAnthropicToResponses_TemperatureStrippedForAllGpt5Variants(t *testing.T) {
temp := 1.0
models := []string{"gpt-5.2", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.5"}
for _, model := range models {
t.Run(model, func(t *testing.T) {
req := &AnthropicRequest{
Model: model,
MaxTokens: 1024,
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
Temperature: &temp,
TopP: &temp,
}
resp, err := AnthropicToResponses(req)
require.NoError(t, err)
assert.Nil(t, resp.Temperature, "model %s: temperature must be stripped", model)
assert.Nil(t, resp.TopP, "model %s: top_p must be stripped", model)
})
}
}

View File

@ -22,12 +22,19 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
}
out := &ResponsesRequest{
Model: req.Model,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: req.Stream,
Include: []string{"reasoning.encrypted_content"},
Model: req.Model,
Input: inputJSON,
Stream: req.Stream,
Include: []string{"reasoning.encrypted_content"},
}
// Reasoning models (gpt-5.x) served via the Responses API do not accept
// sampling parameters. Sending temperature or top_p causes a 400
// "Unsupported parameter" error, so we only forward them for non-reasoning
// models.
if !isReasoningModel(req.Model) {
out.Temperature = req.Temperature
out.TopP = req.TopP
}
storeFalse := false
@ -437,6 +444,14 @@ func boolPtr(v bool) *bool {
return &v
}
// isReasoningModel reports whether model is a reasoning model that does not
// support sampling parameters (temperature, top_p) via the Responses API.
// All gpt-5.x models are reasoning-only; the Responses API returns
// "Unsupported parameter: temperature" if these fields are present.
func isReasoningModel(model string) bool {
return strings.HasPrefix(model, "gpt-5")
}
// normalizeToolParameters ensures the tool parameter schema is valid for
// OpenAI's Responses API, which requires "properties" on object schemas.
//

View File

@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi
assert.Equal(t, "Describe this", parts[0].Text)
}
func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) {
// Regression for #2515: the upstream Responses API rejects an input item
// whose content field is JSON null. Any chat-completions message that
// yields no usable content parts must serialize content as a string.
cases := []struct {
name string
content json.RawMessage
}{
{"null content", json.RawMessage(`null`)},
{"empty array content", json.RawMessage(`[]`)},
{"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)},
{"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-5.5",
Messages: []ChatMessage{
{Role: "user", Content: tc.content},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.NotContains(t, string(resp.Input), `"content":null`,
"converted input must not contain a null content field")
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
assert.Equal(t, `""`, string(items[0].Content),
"content must be an empty string, not null")
})
}
}
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
@ -296,6 +331,48 @@ func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
assert.Equal(t, "flex", resp.ServiceTier)
}
// ---------------------------------------------------------------------------
// temperature / top_p stripping for reasoning models
// ---------------------------------------------------------------------------
func TestChatCompletionsToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
temp := 0.7
req := &ChatCompletionsRequest{
Model: "gpt-5.2",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
Temperature: &temp,
TopP: &temp,
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
// Must not appear in the serialised request body sent to the upstream.
b, err := json.Marshal(resp)
require.NoError(t, err)
assert.NotContains(t, string(b), `"temperature"`)
assert.NotContains(t, string(b), `"top_p"`)
}
func TestChatCompletionsToResponses_TemperaturePreservedForNonReasoningModel(t *testing.T) {
temp := 0.7
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
Temperature: &temp,
TopP: &temp,
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
require.NotNil(t, resp.Temperature, "non-reasoning model: temperature must be preserved")
assert.InDelta(t, 0.7, *resp.Temperature, 1e-9)
require.NotNil(t, resp.TopP, "non-reasoning model: top_p must be preserved")
assert.InDelta(t, 0.7, *resp.TopP, 1e-9)
}
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
@ -379,6 +456,34 @@ func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T)
assert.Contains(t, parts[0].Text, "final answer")
}
func TestChatCompletionsToResponses_AssistantReasoningContentPreserved(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(`"Hi"`)},
{
Role: "assistant",
ReasoningContent: "internal plan",
Content: json.RawMessage(`"final answer"`),
},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 2)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "output_text", parts[0].Type)
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
assert.Contains(t, parts[0].Text, "final answer")
}
// ---------------------------------------------------------------------------
// ResponsesToChatCompletions tests
// ---------------------------------------------------------------------------

View File

@ -30,13 +30,18 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
Model: req.Model,
Instructions: req.Instructions,
Input: inputJSON,
Temperature: req.Temperature,
TopP: req.TopP,
Stream: true, // upstream always streams
Include: []string{"reasoning.encrypted_content"},
ServiceTier: req.ServiceTier,
}
// Reasoning models (gpt-5.x) do not accept sampling parameters.
// See isReasoningModel in anthropic_to_responses.go.
if !isReasoningModel(req.Model) {
out.Temperature = req.Temperature
out.TopP = req.TopP
}
storeFalse := false
out.Store = &storeFalse
@ -150,6 +155,11 @@ func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
// empty/nil and there are tool_calls, only function_call items are emitted.
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
var items []ResponsesInputItem
content := ""
if m.ReasoningContent != "" {
content = "<thinking>" + m.ReasoningContent + "</thinking>"
}
// Emit assistant message with output_text if content is non-empty.
if len(m.Content) > 0 {
@ -158,15 +168,22 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
return nil, err
}
if s != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
if content != "" {
content += "\n"
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
content += s
}
}
if content != "" {
parts := []ResponsesContentPart{{Type: "output_text", Text: content}}
partsJSON, err := json.Marshal(parts)
if err != nil {
return nil, err
}
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
}
// Emit one function_call item per tool_call.
for _, tc := range m.ToolCalls {
args := tc.Function.Arguments
@ -325,7 +342,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error
if content.Text != nil {
return json.Marshal(*content.Text)
}
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
parts := convertChatContentPartsToResponses(content.Parts)
if len(parts) == 0 {
// A nil slice marshals to JSON null, which the upstream Responses API
// rejects ("expected an array of objects or string, but got null").
// Fall back to an empty string when no usable parts remain.
return json.Marshal("")
}
return json.Marshal(parts)
}
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {

View File

@ -75,6 +75,28 @@ type AnthropicContentBlock struct {
IsError bool `json:"is_error,omitempty"`
}
func (b AnthropicContentBlock) MarshalJSON() ([]byte, error) {
type anthropicContentBlock AnthropicContentBlock
base := struct {
anthropicContentBlock
}{anthropicContentBlock: anthropicContentBlock(b)}
switch b.Type {
case "text":
return json.Marshal(struct {
Text string `json:"text"`
anthropicContentBlock
}{Text: b.Text, anthropicContentBlock: anthropicContentBlock(b)})
case "thinking":
return json.Marshal(struct {
Thinking string `json:"thinking"`
anthropicContentBlock
}{Thinking: b.Thinking, anthropicContentBlock: anthropicContentBlock(b)})
default:
return json.Marshal(base)
}
}
// AnthropicImageSource describes the source data for an image content block.
type AnthropicImageSource struct {
Type string `json:"type"` // "base64"
@ -306,6 +328,37 @@ type ResponsesUsage struct {
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
}
func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
type responsesUsageAlias ResponsesUsage
var aux struct {
responsesUsageAlias
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"`
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
*u = ResponsesUsage(aux.responsesUsageAlias)
if u.InputTokens == 0 && aux.PromptTokens != 0 {
u.InputTokens = aux.PromptTokens
}
if u.OutputTokens == 0 && aux.CompletionTokens != 0 {
u.OutputTokens = aux.CompletionTokens
}
if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil {
u.InputTokensDetails = aux.PromptTokensDetails
}
if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil {
u.OutputTokensDetails = aux.CompletionTokensDetails
}
if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) {
u.TotalTokens = u.InputTokens + u.OutputTokens
}
return nil
}
// ResponsesInputTokensDetails breaks down input token usage.
type ResponsesInputTokensDetails struct {
CachedTokens int `json:"cached_tokens,omitempty"`

View File

@ -22,6 +22,7 @@ func DefaultModels() []Model {
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},

View File

@ -15,6 +15,7 @@ var DefaultModels = []Model{
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3.5-flash", Type: "model", DisplayName: "Gemini 3.5 Flash", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},

View File

@ -17,7 +17,7 @@
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems
package openai_compat
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。
//
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
type AccountResponsesSupport int
@ -35,11 +35,43 @@ const (
ResponsesSupportNo
)
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。
type ResponsesSupportMode string
const (
// ResponsesSupportModeAuto 表示跟随自动探测结果。
ResponsesSupportModeAuto ResponsesSupportMode = "auto"
// ResponsesSupportModeForceResponses 强制使用 /v1/responses。
ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses"
// ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。
ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions"
)
// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。
// 值类型为 stringauto=跟随探测force_responses=强制 Responses
// force_chat_completions=强制 Chat Completions。
const ExtraKeyResponsesMode = "openai_responses_mode"
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。
// 值类型为 booltrue=支持、false=不支持、键缺失=未探测。
const ExtraKeyResponsesSupported = "openai_responses_supported"
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。
// 缺失或非法值按 auto 处理,以保持存量行为。
func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode {
switch ResponsesSupportMode(mode) {
case ResponsesSupportModeForceResponses:
return ResponsesSupportModeForceResponses
case ResponsesSupportModeForceChatCompletions:
return ResponsesSupportModeForceChatCompletions
default:
return ResponsesSupportModeAuto
}
}
// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。
//
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI
@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
if extra == nil {
return ResponsesSupportUnknown
}
if mode, ok := extra[ExtraKeyResponsesMode].(string); ok {
switch NormalizeResponsesSupportMode(mode) {
case ResponsesSupportModeForceResponses:
return ResponsesSupportYes
case ResponsesSupportModeForceChatCompletions:
return ResponsesSupportNo
}
}
v, ok := extra[ExtraKeyResponsesSupported]
if !ok {
return ResponsesSupportUnknown

View File

@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) {
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
{"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes},
{"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo},
{"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
{"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
{"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes},
{"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo},
}
for _, tc := range tests {
@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) {
// 已探测:标记决定
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
// 手动覆盖:覆盖自动探测结果
{"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true},
{"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false},
}
for _, tc := range tests {
@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) {
})
}
}
func TestNormalizeResponsesSupportMode(t *testing.T) {
tests := []struct {
name string
mode string
want ResponsesSupportMode
}{
{"empty", "", ResponsesSupportModeAuto},
{"auto", "auto", ResponsesSupportModeAuto},
{"force responses", "force_responses", ResponsesSupportModeForceResponses},
{"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions},
{"invalid", "enabled", ResponsesSupportModeAuto},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := NormalizeResponsesSupportMode(tc.mode)
if got != tc.want {
t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want)
}
})
}
}

View File

@ -230,6 +230,20 @@ type UserDashboardStats struct {
// 性能指标
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
// 按"有效平台"维度拆分(与 ops 路径口径一致group.platform 优先,否则 account.platform
ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"`
}
// PlatformDashboardStats 单个平台的用量明细。
type PlatformDashboardStats struct {
Platform string `json:"platform"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
TotalActualCost float64 `json:"total_actual_cost"`
TodayRequests int64 `json:"today_requests"`
TodayTokens int64 `json:"today_tokens"`
TodayActualCost float64 `json:"today_actual_cost"`
}
// UsageLogFilters represents filters for usage log queries
@ -265,13 +279,22 @@ type UsageStats struct {
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。
// Platform 取值与 ops 路径口径一致:优先 groups.platform否则 accounts.platform。
type PlatformUsage struct {
Platform string `json:"platform"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
ByPlatform []PlatformUsage `json:"by_platform,omitempty"`
}
// BatchAPIKeyUsageStats represents usage stats for a single API key
type BatchAPIKeyUsageStats struct {
APIKeyID int64 `json:"api_key_id"`

Some files were not shown because too many files have changed in this diff Show More