chore: merge upstream v0.1.127 — keep omniroute customizations
Upstream highlights:
- v0.1.127 release (150 commits): channel-monitor 协议管理、OpenAI
Responses 路由配置、模型定价 LiteLLM 默认、payment 强制扫码、
钉钉 OAuth、用户用量按平台拆分、Ops 错误分类 SLA 调整、
Anthropic passthrough keepalive、Gemini chat completions 路由 ...
- 91da8159 feat(risk-control): 内容审计新增关键词拦截
- 3d22dd34 feat: gemini-3.5-flash 模型支持
Conflicts resolved:
- Dockerfile: keep pnpm pin to 9.15.9 (upstream pinned generic v9 floating).
- wire_gen.go: combine upstream NewSettingHandler(+userAttributeService)
with local NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster).
Verified by re-running wire generate.
- scheduler_cache.go: keep both upstream openai_responses_{mode,supported}
keys and local model_rate_limits key in filterSchedulerExtra().
- gateway_service.go: keep local context-compression block; drop now-dead
setOpsUpstreamRequestBody call (upstream removed ops retry replay).
- docker-compose.yml: keep local windsurf-ls service profile and named
volumes; keep local healthcheck start_period values.
Test mock signatures bumped to match current constructors:
- gateway_models_test.go: add nil for RPMTokenBucketService.
- account_handler_available_models_test.go: add nil for windsurfChatService.
This commit is contained in:
commit
158785bfc9
@ -22,8 +22,8 @@ RUN sed -i 's#https://dl-cdn.alpinelinux.org/alpine#https://mirrors.aliyun.com/a
|
||||
|
||||
WORKDIR /app/frontend
|
||||
|
||||
# Install pnpm. Keep this on v9 so Docker builds do not fail on pnpm v10's
|
||||
# interactive build-script approval flow.
|
||||
# Install pnpm. Pin to a specific v9 patch to dodge pnpm v10's interactive
|
||||
# build-script approval flow and keep Docker builds reproducible.
|
||||
RUN corepack enable && corepack prepare pnpm@9.15.9 --activate
|
||||
|
||||
# Install dependencies first (better caching)
|
||||
|
||||
24
README.md
24
README.md
@ -62,13 +62,18 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAi" width="150"></a></td>
|
||||
<td>Thanks to Poixe Ai for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive <a href="https://poixe.com/i/sub2api">sub2api</a> referral link and receive a bonus of $5 USD on your first top-up.</td>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click <a href="https://ctok.ai">here</a> to register!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click <a href="https://ctok.ai">here</a> to register!</td>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via <a href="https://aigocode.com/invite/SUB2API">this link</a>, you'll receive an extra 10% bonus credit on your first top-up!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
|
||||
<td>Thanks to APIKEY.FUN for sponsoring this project! <a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> is one of the core contributors to the sub2api open-source project, dedicated to providing open, stable, and cost-effective AI API access. The platform supports API relay services for Claude, OpenAI, Gemini, and other popular models, with pricing starting from as low as 7% of the original rate. Register via the exclusive link: <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> to enjoy a permanent 5% discount on all recharges.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
@ -86,11 +91,6 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
|
||||
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for sub2api users: register via <a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>Thanks to AIGoCode for sponsoring this project! AIGoCode is an all-in-one platform that integrates Claude Code, Codex, and the latest Gemini models, providing you with stable, efficient, and highly cost-effective AI coding services. The platform offers flexible subscription plans, zero risk of account suspension, direct access with no VPN required, and lightning-fast responses. AIGoCode has prepared a special benefit for sub2api users: if you register via <a href="https://aigocode.com/invite/SUB2API">this link</a>, you'll receive an extra 10% bonus credit on your first top-up!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of 10% of the official GPT subscription price (90% OFF)</td>
|
||||
@ -108,6 +108,12 @@ Enterprise-grade high concurrency is also supported, with a dedicated management
|
||||
Register now via <a href="https://pateway.ai/?ch=1tsfr51">this link</a> to receive $3 in trial credits. User top-ups start as low as 60% off, and referring friends earns both parties rewards — referral bonuses up to $150.</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
|
||||
<td>Thanks to PPToken.org for sponsoring this project! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> specializes in GPT model API relay services, supporting Codex, Claude Code, OpenAI-compatible clients, and Gemini CLI integration. Top-ups are 1:1 (¥1 = $1 credit); GPT models start at 0.16x rate multiplier, with overall cost at roughly 2.2% of official pricing and first-token latency around 1 second — ideal for developers seeking low-cost, high-speed access to GPT model capabilities. Technical support: 24/7 real human responses (no bots), @tech in the group chat and get a reply within 10 minutes. Sponsor benefit: the first 200 users who register via the <a href="https://api.pptoken.org/register?promo=SUB2API">exclusive registration link</a> and enter promo code `SUB2API` can claim free Codex / Claude Code trial credits — no minimum spend, no card required.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## Ecosystem
|
||||
|
||||
23
README_CN.md
23
README_CN.md
@ -61,13 +61,18 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAI" width="150"></a></td>
|
||||
<td>感谢 Poixe AI 赞助了本项目!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 <a href="https://poixe.com/i/sub2api">此链接</a> 专属链接注册,充值额外赠送 $5 美金</td>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击<a href="https://ctok.ai">这里</a>注册!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击<a href="https://ctok.ai">这里</a>注册!</td>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过<a href="https://aigocode.com/invite/SUB2API">此链接</a>注册,首次充值可额外获得 10% 赠送额度!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
|
||||
<td>感谢 APIKEY.FUN 赞助了本项目!<a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> 是 sub2api 开源项目的核心贡献者之一,致力于提供开放、稳定、高性价比的 AI API 接入服务。平台支持 Claude、OpenAI、Gemini 等热门模型的 API 中转服务,价格低至官方原价的 7%。通过专属链接 <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> 注册,可享受所有充值永久 95 折优惠。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
@ -85,11 +90,6 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定性中转服务,企业级并发、快速开票、7×24 小时专属技术支持。Claude Code / Codex / Gemini 官方通道低至原价 38% / 2% / 9%,充值更享额外折扣!AICodeMirror 为 sub2api 用户提供专属福利:通过<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">此链接</a>注册,首次充值立享 8 折优惠,企业客户最高可享 75 折!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>感谢 AIGoCode 赞助了本项目!AIGoCode 是一站式集成 Claude Code、Codex 以及最新 Gemini 模型的综合平台,为您提供稳定、高效、高性价比的 AI 编程服务。平台提供灵活的订阅方案,零封号风险,免 VPN 直连,响应极速。AIGoCode 为 sub2api 用户准备了专属福利:通过<a href="https://aigocode.com/invite/SUB2API">此链接</a>注册,首次充值可额外获得 10% 赠送额度!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||
<td>感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI成品号专卖/代充</a>注册下单的用户,可享GPT 官网订阅一折 的震撼价格!</td>
|
||||
@ -107,6 +107,11 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
|
||||
现在通过 <a href="https://pateway.ai/?ch=1tsfr51">此链接</a> 注册即送 $3 试用额度,用户充值低至 6 折,邀请好友双向赠送,邀请奖励可达 $150。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
|
||||
<td>感谢 PPToken.org 赞助本项目! <a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> 主打 GPT 系列模型 API 中转服务,支持 Codex、Claude Code、OpenAI 兼容客户端及 Gemini CLI 等工具接入。充值 1:1,1 元=1 美元额度;GPT 模型最低 0.16 倍倍率,综合成本约为官方价格的 0.22 折,最快首字 Token 约 1 秒,适合开发者低成本、高响应速度接入 GPT 模型能力。技术支持: 7×24 小时真人响应(不是机器人),群内@技术,10 分钟内有回复 。赞助商福利:前 200 名用户通过 <a href="https://api.pptoken.org/register?promo=SUB2API">[专属注册链接]</a> 注册,输入优惠码 `SUB2API`,可领取 Codex / Claude Code 免费试用额度,无门槛、不绑卡。
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 生态项目
|
||||
|
||||
24
README_JA.md
24
README_JA.md
@ -61,13 +61,18 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://poixe.com/i/sub2api"><img src="assets/partners/logos/poixe.png" alt="PoixeAi" width="150"></a></td>
|
||||
<td>Poixe AI のご支援に感謝します!Poixe AI は信頼性の高い LLM API サービスを提供しています。プラットフォームの API エンドポイントを活用して、AI 搭載プロダクトをシームレスに構築できます。また、ベンダーとして AI API リソースをプラットフォームに提供し、収益を得ることも可能です。専用の <a href="https://poixe.com/i/sub2api">sub2api</a> 紹介リンクから登録すると、初回チャージ時に $5 USD のボーナスがもらえます。</td>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。<a href="https://ctok.ai">こちら</a>から登録!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
|
||||
<td>CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。<a href="https://ctok.ai">こちら</a>から登録!</td>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:<a href="https://aigocode.com/invite/SUB2API">こちらのリンク</a>から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://apikey.fun/register?aff=SUB2API"><img src="assets/partners/logos/apikey-fun.png" alt="APIKEY.FUN" width="150"></a></td>
|
||||
<td>APIKEY.FUN のご支援に感謝します!<a href="https://apikey.fun/register?aff=SUB2API">APIKEY.FUN</a> は sub2api オープンソースプロジェクトのコアコントリビューターの一つであり、オープンで安定した、コストパフォーマンスに優れた AI API アクセスサービスの提供に取り組んでいます。プラットフォームは Claude、OpenAI、Gemini など人気モデルの API 中継サービスをサポートし、価格は公式料金のわずか 7% から。専用リンク <a href="https://apikey.fun/register?aff=SUB2API">APIKEY</a> から登録すると、すべてのチャージで永久 5% 割引をご利用いただけます。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
@ -85,11 +90,6 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
<td>AICodeMirror のご支援に感謝します!AICodeMirror は Claude Code / Codex / Gemini CLI の公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時実行、迅速な請求書発行、24時間年中無休の専属テクニカルサポートを備えています。Claude Code / Codex / Gemini の公式チャネルを定価の 38% / 2% / 9% で利用可能、チャージ時にはさらに追加割引!AICodeMirror は sub2api ユーザー向けに特別特典を提供中:<a href="https://www.aicodemirror.com/register?invitecode=KMVZQM">こちらのリンク</a>から登録すると、初回チャージが 20% オフ、法人のお客様は最大 25% オフ!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://aigocode.com/invite/SUB2API"><img src="assets/partners/logos/aigocode.png" alt="AIGoCode" width="150"></a></td>
|
||||
<td>AIGoCode のご支援に感謝します!AIGoCode は Claude Code、Codex、最新の Gemini モデルを統合したオールインワンプラットフォームで、安定的かつ効率的でコストパフォーマンスに優れた AI コーディングサービスを提供します。柔軟なサブスクリプションプラン、アカウント停止リスクゼロ、VPN 不要の直接アクセス、超高速レスポンスが特長です。AIGoCode は sub2api ユーザー向けに特別特典を用意しています:<a href="https://aigocode.com/invite/SUB2API">こちらのリンク</a>から登録すると、初回チャージ時に 10% のボーナスクレジットを追加プレゼント!</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="assets/partners/logos/bmoplus.jpg" alt="bmoplus" width="150"></a></td>
|
||||
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます!</td>
|
||||
@ -107,6 +107,12 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
|
||||
<a href="https://pateway.ai/?ch=1tsfr51">こちらのリンク</a>から登録すると、$3 のトライアルクレジットがもらえます。チャージは最大40%オフ、友達紹介で双方にボーナス付与 — 紹介報酬は最大 $150。</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td width="180"><a href="https://api.pptoken.org/register?promo=SUB2API"><img src="assets/partners/logos/pptoken.png" alt="pptoken" width="150"></a></td>
|
||||
<td>PPToken.org のご支援に感謝します!<a href="https://api.pptoken.org/register?promo=SUB2API">PPToken.org</a> は GPT シリーズモデルの API 中継サービスを専門としており、Codex、Claude Code、OpenAI 互換クライアント、Gemini CLI などのツール接続をサポートしています。チャージは 1:1(1元=1ドル分のクレジット)、GPT モデルは最低 0.16 倍のレート倍率で、総合コストは公式価格の約 2.2% 、最速ファーストトークンは約1秒 — 開発者が低コスト・高速レスポンスで GPT モデル機能にアクセスするのに最適です。テクニカルサポート:24時間365日リアルな人間が対応(ボットではありません)、グループ内で @技術 すれば 10 分以内に返信。スポンサー特典:先着 200 名のユーザーが<a href="https://api.pptoken.org/register?promo=SUB2API">専用登録リンク</a>から登録し、プロモコード `SUB2API` を入力すると、Codex / Claude Code の無料トライアルクレジットを獲得できます — 最低利用額なし、カード登録不要。
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
## エコシステム
|
||||
|
||||
BIN
assets/partners/logos/apikey-fun.png
Normal file
BIN
assets/partners/logos/apikey-fun.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 33 KiB |
BIN
assets/partners/logos/pptoken.png
Normal file
BIN
assets/partners/logos/pptoken.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 69 KiB |
@ -1 +1 @@
|
||||
0.1.126
|
||||
0.1.127
|
||||
|
||||
@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
}
|
||||
totpCache := repository.NewTotpCache(redisClient)
|
||||
totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService)
|
||||
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService)
|
||||
userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
@ -202,7 +205,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
registry := payment.ProvideRegistry()
|
||||
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
|
||||
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
|
||||
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService)
|
||||
requestEventBus := service.NewRequestEventBus()
|
||||
opsLogBroadcaster := service.ProvideOpsLogBroadcaster()
|
||||
opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster)
|
||||
@ -217,9 +220,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
|
||||
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
|
||||
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
|
||||
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
|
||||
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
|
||||
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
|
||||
userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService)
|
||||
errorPassthroughRepository := repository.NewErrorPassthroughRepository(client)
|
||||
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
|
||||
@ -231,7 +231,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
scheduledTestResultRepository := repository.NewScheduledTestResultRepository(db)
|
||||
scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository)
|
||||
scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService)
|
||||
channelHandler := admin.NewChannelHandler(channelService, billingService, pricingService)
|
||||
channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService)
|
||||
channelMonitorRequestTemplateRepository := repository.NewChannelMonitorRequestTemplateRepository(client, db)
|
||||
channelMonitorRequestTemplateService := service.NewChannelMonitorRequestTemplateService(channelMonitorRequestTemplateRepository)
|
||||
|
||||
@ -27,6 +27,8 @@ type ChannelMonitor struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
// Provider holds the value of the "provider" field.
|
||||
Provider channelmonitor.Provider `json:"provider,omitempty"`
|
||||
// OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
|
||||
APIMode string `json:"api_mode,omitempty"`
|
||||
// Provider base origin, e.g. https://api.openai.com
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
// AES-256-GCM encrypted API key
|
||||
@ -112,7 +114,7 @@ func (*ChannelMonitor) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullBool)
|
||||
case channelmonitor.FieldID, channelmonitor.FieldIntervalSeconds, channelmonitor.FieldCreatedBy, channelmonitor.FieldTemplateID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
|
||||
case channelmonitor.FieldName, channelmonitor.FieldProvider, channelmonitor.FieldAPIMode, channelmonitor.FieldEndpoint, channelmonitor.FieldAPIKeyEncrypted, channelmonitor.FieldPrimaryModel, channelmonitor.FieldGroupName, channelmonitor.FieldBodyOverrideMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case channelmonitor.FieldCreatedAt, channelmonitor.FieldUpdatedAt, channelmonitor.FieldLastCheckedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@ -161,6 +163,12 @@ func (_m *ChannelMonitor) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.Provider = channelmonitor.Provider(value.String)
|
||||
}
|
||||
case channelmonitor.FieldAPIMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field api_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.APIMode = value.String
|
||||
}
|
||||
case channelmonitor.FieldEndpoint:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field endpoint", values[i])
|
||||
@ -310,6 +318,9 @@ func (_m *ChannelMonitor) String() string {
|
||||
builder.WriteString("provider=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("api_mode=")
|
||||
builder.WriteString(_m.APIMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("endpoint=")
|
||||
builder.WriteString(_m.Endpoint)
|
||||
builder.WriteString(", ")
|
||||
|
||||
@ -23,6 +23,8 @@ const (
|
||||
FieldName = "name"
|
||||
// FieldProvider holds the string denoting the provider field in the database.
|
||||
FieldProvider = "provider"
|
||||
// FieldAPIMode holds the string denoting the api_mode field in the database.
|
||||
FieldAPIMode = "api_mode"
|
||||
// FieldEndpoint holds the string denoting the endpoint field in the database.
|
||||
FieldEndpoint = "endpoint"
|
||||
// FieldAPIKeyEncrypted holds the string denoting the api_key_encrypted field in the database.
|
||||
@ -87,6 +89,7 @@ var Columns = []string{
|
||||
FieldUpdatedAt,
|
||||
FieldName,
|
||||
FieldProvider,
|
||||
FieldAPIMode,
|
||||
FieldEndpoint,
|
||||
FieldAPIKeyEncrypted,
|
||||
FieldPrimaryModel,
|
||||
@ -121,6 +124,10 @@ var (
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||
NameValidator func(string) error
|
||||
// DefaultAPIMode holds the default value on creation for the "api_mode" field.
|
||||
DefaultAPIMode string
|
||||
// APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
|
||||
APIModeValidator func(string) error
|
||||
// EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
|
||||
EndpointValidator func(string) error
|
||||
// APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
|
||||
@ -197,6 +204,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldProvider, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIMode orders the results by the api_mode field.
|
||||
func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByEndpoint orders the results by the endpoint field.
|
||||
func ByEndpoint(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldEndpoint, opts...).ToFunc()
|
||||
|
||||
@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEQ(FieldName, v))
|
||||
}
|
||||
|
||||
// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
|
||||
func APIMode(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// Endpoint applies equality check predicate on the "endpoint" field. It's identical to EndpointEQ.
|
||||
func Endpoint(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
|
||||
@ -285,6 +290,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldNotIn(FieldProvider, vs...))
|
||||
}
|
||||
|
||||
// APIModeEQ applies the EQ predicate on the "api_mode" field.
|
||||
func APIModeEQ(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
|
||||
func APIModeNEQ(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldNEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeIn applies the In predicate on the "api_mode" field.
|
||||
func APIModeIn(vs ...string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldIn(FieldAPIMode, vs...))
|
||||
}
|
||||
|
||||
// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
|
||||
func APIModeNotIn(vs ...string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldNotIn(FieldAPIMode, vs...))
|
||||
}
|
||||
|
||||
// APIModeGT applies the GT predicate on the "api_mode" field.
|
||||
func APIModeGT(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldGT(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeGTE applies the GTE predicate on the "api_mode" field.
|
||||
func APIModeGTE(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldGTE(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeLT applies the LT predicate on the "api_mode" field.
|
||||
func APIModeLT(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldLT(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeLTE applies the LTE predicate on the "api_mode" field.
|
||||
func APIModeLTE(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldLTE(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeContains applies the Contains predicate on the "api_mode" field.
|
||||
func APIModeContains(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldContains(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
|
||||
func APIModeHasPrefix(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldHasPrefix(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
|
||||
func APIModeHasSuffix(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldHasSuffix(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
|
||||
func APIModeEqualFold(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEqualFold(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
|
||||
func APIModeContainsFold(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldContainsFold(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// EndpointEQ applies the EQ predicate on the "endpoint" field.
|
||||
func EndpointEQ(v string) predicate.ChannelMonitor {
|
||||
return predicate.ChannelMonitor(sql.FieldEQ(FieldEndpoint, v))
|
||||
|
||||
@ -65,6 +65,20 @@ func (_c *ChannelMonitorCreate) SetProvider(v channelmonitor.Provider) *ChannelM
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_c *ChannelMonitorCreate) SetAPIMode(v string) *ChannelMonitorCreate {
|
||||
_c.mutation.SetAPIMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_c *ChannelMonitorCreate) SetNillableAPIMode(v *string) *ChannelMonitorCreate {
|
||||
if v != nil {
|
||||
_c.SetAPIMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (_c *ChannelMonitorCreate) SetEndpoint(v string) *ChannelMonitorCreate {
|
||||
_c.mutation.SetEndpoint(v)
|
||||
@ -275,6 +289,10 @@ func (_c *ChannelMonitorCreate) defaults() {
|
||||
v := channelmonitor.DefaultUpdatedAt()
|
||||
_c.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
if _, ok := _c.mutation.APIMode(); !ok {
|
||||
v := channelmonitor.DefaultAPIMode
|
||||
_c.mutation.SetAPIMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.ExtraModels(); !ok {
|
||||
v := channelmonitor.DefaultExtraModels
|
||||
_c.mutation.SetExtraModels(v)
|
||||
@ -321,6 +339,14 @@ func (_c *ChannelMonitorCreate) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.APIMode(); !ok {
|
||||
return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitor.api_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.APIMode(); ok {
|
||||
if err := channelmonitor.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.Endpoint(); !ok {
|
||||
return &ValidationError{Name: "endpoint", err: errors.New(`ent: missing required field "ChannelMonitor.endpoint"`)}
|
||||
}
|
||||
@ -421,6 +447,10 @@ func (_c *ChannelMonitorCreate) createSpec() (*ChannelMonitor, *sqlgraph.CreateS
|
||||
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
|
||||
_node.Provider = value
|
||||
}
|
||||
if value, ok := _c.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
|
||||
_node.APIMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Endpoint(); ok {
|
||||
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
|
||||
_node.Endpoint = value
|
||||
@ -606,6 +636,18 @@ func (u *ChannelMonitorUpsert) UpdateProvider() *ChannelMonitorUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorUpsert) SetAPIMode(v string) *ChannelMonitorUpsert {
|
||||
u.Set(channelmonitor.FieldAPIMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorUpsert) UpdateAPIMode() *ChannelMonitorUpsert {
|
||||
u.SetExcluded(channelmonitor.FieldAPIMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (u *ChannelMonitorUpsert) SetEndpoint(v string) *ChannelMonitorUpsert {
|
||||
u.Set(channelmonitor.FieldEndpoint, v)
|
||||
@ -885,6 +927,20 @@ func (u *ChannelMonitorUpsertOne) UpdateProvider() *ChannelMonitorUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorUpsertOne) SetAPIMode(v string) *ChannelMonitorUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
s.SetAPIMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorUpsertOne) UpdateAPIMode() *ChannelMonitorUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
s.UpdateAPIMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (u *ChannelMonitorUpsertOne) SetEndpoint(v string) *ChannelMonitorUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
@ -1362,6 +1418,20 @@ func (u *ChannelMonitorUpsertBulk) UpdateProvider() *ChannelMonitorUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorUpsertBulk) SetAPIMode(v string) *ChannelMonitorUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
s.SetAPIMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorUpsertBulk) UpdateAPIMode() *ChannelMonitorUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
s.UpdateAPIMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (u *ChannelMonitorUpsertBulk) SetEndpoint(v string) *ChannelMonitorUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorUpsert) {
|
||||
|
||||
@ -66,6 +66,20 @@ func (_u *ChannelMonitorUpdate) SetNillableProvider(v *channelmonitor.Provider)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_u *ChannelMonitorUpdate) SetAPIMode(v string) *ChannelMonitorUpdate {
|
||||
_u.mutation.SetAPIMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_u *ChannelMonitorUpdate) SetNillableAPIMode(v *string) *ChannelMonitorUpdate {
|
||||
if v != nil {
|
||||
_u.SetAPIMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (_u *ChannelMonitorUpdate) SetEndpoint(v string) *ChannelMonitorUpdate {
|
||||
_u.mutation.SetEndpoint(v)
|
||||
@ -418,6 +432,11 @@ func (_u *ChannelMonitorUpdate) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.APIMode(); ok {
|
||||
if err := channelmonitor.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Endpoint(); ok {
|
||||
if err := channelmonitor.EndpointValidator(v); err != nil {
|
||||
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
|
||||
@ -472,6 +491,9 @@ func (_u *ChannelMonitorUpdate) sqlSave(ctx context.Context) (_node int, err err
|
||||
if value, ok := _u.mutation.Provider(); ok {
|
||||
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
|
||||
}
|
||||
if value, ok := _u.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Endpoint(); ok {
|
||||
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
|
||||
}
|
||||
@ -701,6 +723,20 @@ func (_u *ChannelMonitorUpdateOne) SetNillableProvider(v *channelmonitor.Provide
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_u *ChannelMonitorUpdateOne) SetAPIMode(v string) *ChannelMonitorUpdateOne {
|
||||
_u.mutation.SetAPIMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_u *ChannelMonitorUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAPIMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (_u *ChannelMonitorUpdateOne) SetEndpoint(v string) *ChannelMonitorUpdateOne {
|
||||
_u.mutation.SetEndpoint(v)
|
||||
@ -1066,6 +1102,11 @@ func (_u *ChannelMonitorUpdateOne) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.APIMode(); ok {
|
||||
if err := channelmonitor.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Endpoint(); ok {
|
||||
if err := channelmonitor.EndpointValidator(v); err != nil {
|
||||
return &ValidationError{Name: "endpoint", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitor.endpoint": %w`, err)}
|
||||
@ -1137,6 +1178,9 @@ func (_u *ChannelMonitorUpdateOne) sqlSave(ctx context.Context) (_node *ChannelM
|
||||
if value, ok := _u.mutation.Provider(); ok {
|
||||
_spec.SetField(channelmonitor.FieldProvider, field.TypeEnum, value)
|
||||
}
|
||||
if value, ok := _u.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitor.FieldAPIMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Endpoint(); ok {
|
||||
_spec.SetField(channelmonitor.FieldEndpoint, field.TypeString, value)
|
||||
}
|
||||
|
||||
@ -26,6 +26,8 @@ type ChannelMonitorRequestTemplate struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
// Provider holds the value of the "provider" field.
|
||||
Provider channelmonitorrequesttemplate.Provider `json:"provider,omitempty"`
|
||||
// OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions
|
||||
APIMode string `json:"api_mode,omitempty"`
|
||||
// Description holds the value of the "description" field.
|
||||
Description string `json:"description,omitempty"`
|
||||
// ExtraHeaders holds the value of the "extra_headers" field.
|
||||
@ -67,7 +69,7 @@ func (*ChannelMonitorRequestTemplate) scanValues(columns []string) ([]any, error
|
||||
values[i] = new([]byte)
|
||||
case channelmonitorrequesttemplate.FieldID:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
|
||||
case channelmonitorrequesttemplate.FieldName, channelmonitorrequesttemplate.FieldProvider, channelmonitorrequesttemplate.FieldAPIMode, channelmonitorrequesttemplate.FieldDescription, channelmonitorrequesttemplate.FieldBodyOverrideMode:
|
||||
values[i] = new(sql.NullString)
|
||||
case channelmonitorrequesttemplate.FieldCreatedAt, channelmonitorrequesttemplate.FieldUpdatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@ -116,6 +118,12 @@ func (_m *ChannelMonitorRequestTemplate) assignValues(columns []string, values [
|
||||
} else if value.Valid {
|
||||
_m.Provider = channelmonitorrequesttemplate.Provider(value.String)
|
||||
}
|
||||
case channelmonitorrequesttemplate.FieldAPIMode:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field api_mode", values[i])
|
||||
} else if value.Valid {
|
||||
_m.APIMode = value.String
|
||||
}
|
||||
case channelmonitorrequesttemplate.FieldDescription:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field description", values[i])
|
||||
@ -197,6 +205,9 @@ func (_m *ChannelMonitorRequestTemplate) String() string {
|
||||
builder.WriteString("provider=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.Provider))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("api_mode=")
|
||||
builder.WriteString(_m.APIMode)
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("description=")
|
||||
builder.WriteString(_m.Description)
|
||||
builder.WriteString(", ")
|
||||
|
||||
@ -23,6 +23,8 @@ const (
|
||||
FieldName = "name"
|
||||
// FieldProvider holds the string denoting the provider field in the database.
|
||||
FieldProvider = "provider"
|
||||
// FieldAPIMode holds the string denoting the api_mode field in the database.
|
||||
FieldAPIMode = "api_mode"
|
||||
// FieldDescription holds the string denoting the description field in the database.
|
||||
FieldDescription = "description"
|
||||
// FieldExtraHeaders holds the string denoting the extra_headers field in the database.
|
||||
@ -51,6 +53,7 @@ var Columns = []string{
|
||||
FieldUpdatedAt,
|
||||
FieldName,
|
||||
FieldProvider,
|
||||
FieldAPIMode,
|
||||
FieldDescription,
|
||||
FieldExtraHeaders,
|
||||
FieldBodyOverrideMode,
|
||||
@ -76,6 +79,10 @@ var (
|
||||
UpdateDefaultUpdatedAt func() time.Time
|
||||
// NameValidator is a validator for the "name" field. It is called by the builders before save.
|
||||
NameValidator func(string) error
|
||||
// DefaultAPIMode holds the default value on creation for the "api_mode" field.
|
||||
DefaultAPIMode string
|
||||
// APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
|
||||
APIModeValidator func(string) error
|
||||
// DefaultDescription holds the default value on creation for the "description" field.
|
||||
DefaultDescription string
|
||||
// DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
|
||||
@ -140,6 +147,11 @@ func ByProvider(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldProvider, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByAPIMode orders the results by the api_mode field.
|
||||
func ByAPIMode(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldAPIMode, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByDescription orders the results by the description field.
|
||||
func ByDescription(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldDescription, opts...).ToFunc()
|
||||
|
||||
@ -70,6 +70,11 @@ func Name(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldName, v))
|
||||
}
|
||||
|
||||
// APIMode applies equality check predicate on the "api_mode" field. It's identical to APIModeEQ.
|
||||
func APIMode(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ.
|
||||
func Description(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
|
||||
@ -245,6 +250,71 @@ func ProviderNotIn(vs ...Provider) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldProvider, vs...))
|
||||
}
|
||||
|
||||
// APIModeEQ applies the EQ predicate on the "api_mode" field.
|
||||
func APIModeEQ(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeNEQ applies the NEQ predicate on the "api_mode" field.
|
||||
func APIModeNEQ(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldNEQ(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeIn applies the In predicate on the "api_mode" field.
|
||||
func APIModeIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldIn(FieldAPIMode, vs...))
|
||||
}
|
||||
|
||||
// APIModeNotIn applies the NotIn predicate on the "api_mode" field.
|
||||
func APIModeNotIn(vs ...string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldNotIn(FieldAPIMode, vs...))
|
||||
}
|
||||
|
||||
// APIModeGT applies the GT predicate on the "api_mode" field.
|
||||
func APIModeGT(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldGT(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeGTE applies the GTE predicate on the "api_mode" field.
|
||||
func APIModeGTE(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldGTE(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeLT applies the LT predicate on the "api_mode" field.
|
||||
func APIModeLT(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldLT(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeLTE applies the LTE predicate on the "api_mode" field.
|
||||
func APIModeLTE(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldLTE(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeContains applies the Contains predicate on the "api_mode" field.
|
||||
func APIModeContains(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldContains(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeHasPrefix applies the HasPrefix predicate on the "api_mode" field.
|
||||
func APIModeHasPrefix(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldHasPrefix(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeHasSuffix applies the HasSuffix predicate on the "api_mode" field.
|
||||
func APIModeHasSuffix(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldHasSuffix(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeEqualFold applies the EqualFold predicate on the "api_mode" field.
|
||||
func APIModeEqualFold(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEqualFold(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// APIModeContainsFold applies the ContainsFold predicate on the "api_mode" field.
|
||||
func APIModeContainsFold(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldContainsFold(FieldAPIMode, v))
|
||||
}
|
||||
|
||||
// DescriptionEQ applies the EQ predicate on the "description" field.
|
||||
func DescriptionEQ(v string) predicate.ChannelMonitorRequestTemplate {
|
||||
return predicate.ChannelMonitorRequestTemplate(sql.FieldEQ(FieldDescription, v))
|
||||
|
||||
@ -63,6 +63,20 @@ func (_c *ChannelMonitorRequestTemplateCreate) SetProvider(v channelmonitorreque
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_c *ChannelMonitorRequestTemplateCreate) SetAPIMode(v string) *ChannelMonitorRequestTemplateCreate {
|
||||
_c.mutation.SetAPIMode(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_c *ChannelMonitorRequestTemplateCreate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateCreate {
|
||||
if v != nil {
|
||||
_c.SetAPIMode(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_c *ChannelMonitorRequestTemplateCreate) SetDescription(v string) *ChannelMonitorRequestTemplateCreate {
|
||||
_c.mutation.SetDescription(v)
|
||||
@ -161,6 +175,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) defaults() {
|
||||
v := channelmonitorrequesttemplate.DefaultUpdatedAt()
|
||||
_c.mutation.SetUpdatedAt(v)
|
||||
}
|
||||
if _, ok := _c.mutation.APIMode(); !ok {
|
||||
v := channelmonitorrequesttemplate.DefaultAPIMode
|
||||
_c.mutation.SetAPIMode(v)
|
||||
}
|
||||
if _, ok := _c.mutation.Description(); !ok {
|
||||
v := channelmonitorrequesttemplate.DefaultDescription
|
||||
_c.mutation.SetDescription(v)
|
||||
@ -199,6 +217,14 @@ func (_c *ChannelMonitorRequestTemplateCreate) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.APIMode(); !ok {
|
||||
return &ValidationError{Name: "api_mode", err: errors.New(`ent: missing required field "ChannelMonitorRequestTemplate.api_mode"`)}
|
||||
}
|
||||
if v, ok := _c.mutation.APIMode(); ok {
|
||||
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.Description(); ok {
|
||||
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
|
||||
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
|
||||
@ -258,6 +284,10 @@ func (_c *ChannelMonitorRequestTemplateCreate) createSpec() (*ChannelMonitorRequ
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
|
||||
_node.Provider = value
|
||||
}
|
||||
if value, ok := _c.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
|
||||
_node.APIMode = value
|
||||
}
|
||||
if value, ok := _c.mutation.Description(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
|
||||
_node.Description = value
|
||||
@ -378,6 +408,18 @@ func (u *ChannelMonitorRequestTemplateUpsert) UpdateProvider() *ChannelMonitorRe
|
||||
return u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsert) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsert {
|
||||
u.Set(channelmonitorrequesttemplate.FieldAPIMode, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorRequestTemplateUpsert) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsert {
|
||||
u.SetExcluded(channelmonitorrequesttemplate.FieldAPIMode)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsert) SetDescription(v string) *ChannelMonitorRequestTemplateUpsert {
|
||||
u.Set(channelmonitorrequesttemplate.FieldDescription, v)
|
||||
@ -525,6 +567,20 @@ func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateProvider() *ChannelMonito
|
||||
})
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
s.SetAPIMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertOne) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
s.UpdateAPIMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertOne {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
@ -848,6 +904,20 @@ func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateProvider() *ChannelMonit
|
||||
})
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertBulk) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
s.SetAPIMode(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAPIMode sets the "api_mode" field to the value that was provided on create.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertBulk) UpdateAPIMode() *ChannelMonitorRequestTemplateUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
s.UpdateAPIMode()
|
||||
})
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (u *ChannelMonitorRequestTemplateUpsertBulk) SetDescription(v string) *ChannelMonitorRequestTemplateUpsertBulk {
|
||||
return u.Update(func(s *ChannelMonitorRequestTemplateUpsert) {
|
||||
|
||||
@ -63,6 +63,20 @@ func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableProvider(v *channelmon
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdate) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdate {
|
||||
_u.mutation.SetAPIMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdate) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdate {
|
||||
if v != nil {
|
||||
_u.SetAPIMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdate) SetDescription(v string) *ChannelMonitorRequestTemplateUpdate {
|
||||
_u.mutation.SetDescription(v)
|
||||
@ -204,6 +218,11 @@ func (_u *ChannelMonitorRequestTemplateUpdate) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.APIMode(); ok {
|
||||
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Description(); ok {
|
||||
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
|
||||
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
|
||||
@ -238,6 +257,9 @@ func (_u *ChannelMonitorRequestTemplateUpdate) sqlSave(ctx context.Context) (_no
|
||||
if value, ok := _u.mutation.Provider(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
|
||||
}
|
||||
if value, ok := _u.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
@ -355,6 +377,20 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableProvider(v *channel
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetAPIMode(v string) *ChannelMonitorRequestTemplateUpdateOne {
|
||||
_u.mutation.SetAPIMode(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableAPIMode sets the "api_mode" field if the given value is not nil.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetNillableAPIMode(v *string) *ChannelMonitorRequestTemplateUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetAPIMode(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (_u *ChannelMonitorRequestTemplateUpdateOne) SetDescription(v string) *ChannelMonitorRequestTemplateUpdateOne {
|
||||
_u.mutation.SetDescription(v)
|
||||
@ -509,6 +545,11 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) check() error {
|
||||
return &ValidationError{Name: "provider", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.provider": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.APIMode(); ok {
|
||||
if err := channelmonitorrequesttemplate.APIModeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "api_mode", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.api_mode": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.Description(); ok {
|
||||
if err := channelmonitorrequesttemplate.DescriptionValidator(v); err != nil {
|
||||
return &ValidationError{Name: "description", err: fmt.Errorf(`ent: validator failed for field "ChannelMonitorRequestTemplate.description": %w`, err)}
|
||||
@ -560,6 +601,9 @@ func (_u *ChannelMonitorRequestTemplateUpdateOne) sqlSave(ctx context.Context) (
|
||||
if value, ok := _u.mutation.Provider(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldProvider, field.TypeEnum, value)
|
||||
}
|
||||
if value, ok := _u.mutation.APIMode(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldAPIMode, field.TypeString, value)
|
||||
}
|
||||
if value, ok := _u.mutation.Description(); ok {
|
||||
_spec.SetField(channelmonitorrequesttemplate.FieldDescription, field.TypeString, value)
|
||||
}
|
||||
|
||||
@ -428,6 +428,7 @@ var (
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "name", Type: field.TypeString, Size: 100},
|
||||
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
|
||||
{Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
|
||||
{Name: "endpoint", Type: field.TypeString, Size: 500},
|
||||
{Name: "api_key_encrypted", Type: field.TypeString},
|
||||
{Name: "primary_model", Type: field.TypeString, Size: 200},
|
||||
@ -450,7 +451,7 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "channel_monitors_channel_monitor_request_templates_request_template",
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[17]},
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[18]},
|
||||
RefColumns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@ -459,22 +460,27 @@ var (
|
||||
{
|
||||
Name: "channelmonitor_enabled_last_checked_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[10], ChannelMonitorsColumns[12]},
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[11], ChannelMonitorsColumns[13]},
|
||||
},
|
||||
{
|
||||
Name: "channelmonitor_provider",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[4]},
|
||||
},
|
||||
{
|
||||
Name: "channelmonitor_provider_api_mode",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[4], ChannelMonitorsColumns[5]},
|
||||
},
|
||||
{
|
||||
Name: "channelmonitor_group_name",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[9]},
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "channelmonitor_template_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[17]},
|
||||
Columns: []*schema.Column{ChannelMonitorsColumns[18]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -566,6 +572,7 @@ var (
|
||||
{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "name", Type: field.TypeString, Size: 100},
|
||||
{Name: "provider", Type: field.TypeEnum, Enums: []string{"openai", "anthropic", "gemini"}},
|
||||
{Name: "api_mode", Type: field.TypeString, Size: 32, Default: "chat_completions"},
|
||||
{Name: "description", Type: field.TypeString, Nullable: true, Size: 500, Default: ""},
|
||||
{Name: "extra_headers", Type: field.TypeJSON},
|
||||
{Name: "body_override_mode", Type: field.TypeString, Size: 10, Default: "off"},
|
||||
@ -582,6 +589,11 @@ var (
|
||||
Unique: true,
|
||||
Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[3]},
|
||||
},
|
||||
{
|
||||
Name: "channelmonitorrequesttemplate_provider_api_mode",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{ChannelMonitorRequestTemplatesColumns[4], ChannelMonitorRequestTemplatesColumns[5]},
|
||||
},
|
||||
},
|
||||
}
|
||||
// ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table.
|
||||
@ -1120,6 +1132,7 @@ var (
|
||||
{Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "validity_days", Type: field.TypeInt, Default: 30},
|
||||
{Name: "group_id", Type: field.TypeInt64, Nullable: true},
|
||||
{Name: "used_by", Type: field.TypeInt64, Nullable: true},
|
||||
@ -1132,13 +1145,13 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "redeem_codes_groups_redeem_codes",
|
||||
Columns: []*schema.Column{RedeemCodesColumns[9]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "redeem_codes_users_redeem_codes",
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[11]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@ -1152,12 +1165,17 @@ var (
|
||||
{
|
||||
Name: "redeemcode_used_by",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[11]},
|
||||
},
|
||||
{
|
||||
Name: "redeemcode_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[9]},
|
||||
Columns: []*schema.Column{RedeemCodesColumns[10]},
|
||||
},
|
||||
{
|
||||
Name: "redeemcode_expires_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{RedeemCodesColumns[8]},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -1318,6 +1336,10 @@ var (
|
||||
{Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45},
|
||||
{Name: "image_count", Type: field.TypeInt, Default: 0},
|
||||
{Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10},
|
||||
{Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32},
|
||||
{Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32},
|
||||
{Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16},
|
||||
{Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
|
||||
{Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false},
|
||||
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
|
||||
{Name: "api_key_id", Type: field.TypeInt64},
|
||||
@ -1334,31 +1356,31 @@ var (
|
||||
ForeignKeys: []*schema.ForeignKey{
|
||||
{
|
||||
Symbol: "usage_logs_api_keys_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
RefColumns: []*schema.Column{APIKeysColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_accounts_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
RefColumns: []*schema.Column{AccountsColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_groups_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39]},
|
||||
RefColumns: []*schema.Column{GroupsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_users_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40]},
|
||||
RefColumns: []*schema.Column{UsersColumns[0]},
|
||||
OnDelete: schema.NoAction,
|
||||
},
|
||||
{
|
||||
Symbol: "usage_logs_user_subscriptions_usage_logs",
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[41]},
|
||||
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
|
||||
OnDelete: schema.SetNull,
|
||||
},
|
||||
@ -1367,32 +1389,32 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_account_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[34]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[38]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[35]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_subscription_id",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[37]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[41]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_model",
|
||||
@ -1412,17 +1434,17 @@ var (
|
||||
{
|
||||
Name: "usagelog_user_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_api_key_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]},
|
||||
},
|
||||
{
|
||||
Name: "usagelog_group_id_created_at",
|
||||
Unique: false,
|
||||
Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]},
|
||||
Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -8752,6 +8752,7 @@ type ChannelMonitorMutation struct {
|
||||
updated_at *time.Time
|
||||
name *string
|
||||
provider *channelmonitor.Provider
|
||||
api_mode *string
|
||||
endpoint *string
|
||||
api_key_encrypted *string
|
||||
primary_model *string
|
||||
@ -9023,6 +9024,42 @@ func (m *ChannelMonitorMutation) ResetProvider() {
|
||||
m.provider = nil
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (m *ChannelMonitorMutation) SetAPIMode(s string) {
|
||||
m.api_mode = &s
|
||||
}
|
||||
|
||||
// APIMode returns the value of the "api_mode" field in the mutation.
|
||||
func (m *ChannelMonitorMutation) APIMode() (r string, exists bool) {
|
||||
v := m.api_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitor entity.
|
||||
// If the ChannelMonitor object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *ChannelMonitorMutation) OldAPIMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAPIMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
|
||||
}
|
||||
return oldValue.APIMode, nil
|
||||
}
|
||||
|
||||
// ResetAPIMode resets all changes to the "api_mode" field.
|
||||
func (m *ChannelMonitorMutation) ResetAPIMode() {
|
||||
m.api_mode = nil
|
||||
}
|
||||
|
||||
// SetEndpoint sets the "endpoint" field.
|
||||
func (m *ChannelMonitorMutation) SetEndpoint(s string) {
|
||||
m.endpoint = &s
|
||||
@ -9780,7 +9817,7 @@ func (m *ChannelMonitorMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *ChannelMonitorMutation) Fields() []string {
|
||||
fields := make([]string, 0, 17)
|
||||
fields := make([]string, 0, 18)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, channelmonitor.FieldCreatedAt)
|
||||
}
|
||||
@ -9793,6 +9830,9 @@ func (m *ChannelMonitorMutation) Fields() []string {
|
||||
if m.provider != nil {
|
||||
fields = append(fields, channelmonitor.FieldProvider)
|
||||
}
|
||||
if m.api_mode != nil {
|
||||
fields = append(fields, channelmonitor.FieldAPIMode)
|
||||
}
|
||||
if m.endpoint != nil {
|
||||
fields = append(fields, channelmonitor.FieldEndpoint)
|
||||
}
|
||||
@ -9848,6 +9888,8 @@ func (m *ChannelMonitorMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Name()
|
||||
case channelmonitor.FieldProvider:
|
||||
return m.Provider()
|
||||
case channelmonitor.FieldAPIMode:
|
||||
return m.APIMode()
|
||||
case channelmonitor.FieldEndpoint:
|
||||
return m.Endpoint()
|
||||
case channelmonitor.FieldAPIKeyEncrypted:
|
||||
@ -9891,6 +9933,8 @@ func (m *ChannelMonitorMutation) OldField(ctx context.Context, name string) (ent
|
||||
return m.OldName(ctx)
|
||||
case channelmonitor.FieldProvider:
|
||||
return m.OldProvider(ctx)
|
||||
case channelmonitor.FieldAPIMode:
|
||||
return m.OldAPIMode(ctx)
|
||||
case channelmonitor.FieldEndpoint:
|
||||
return m.OldEndpoint(ctx)
|
||||
case channelmonitor.FieldAPIKeyEncrypted:
|
||||
@ -9954,6 +9998,13 @@ func (m *ChannelMonitorMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetProvider(v)
|
||||
return nil
|
||||
case channelmonitor.FieldAPIMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAPIMode(v)
|
||||
return nil
|
||||
case channelmonitor.FieldEndpoint:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
@ -10160,6 +10211,9 @@ func (m *ChannelMonitorMutation) ResetField(name string) error {
|
||||
case channelmonitor.FieldProvider:
|
||||
m.ResetProvider()
|
||||
return nil
|
||||
case channelmonitor.FieldAPIMode:
|
||||
m.ResetAPIMode()
|
||||
return nil
|
||||
case channelmonitor.FieldEndpoint:
|
||||
m.ResetEndpoint()
|
||||
return nil
|
||||
@ -12591,6 +12645,7 @@ type ChannelMonitorRequestTemplateMutation struct {
|
||||
updated_at *time.Time
|
||||
name *string
|
||||
provider *channelmonitorrequesttemplate.Provider
|
||||
api_mode *string
|
||||
description *string
|
||||
extra_headers *map[string]string
|
||||
body_override_mode *string
|
||||
@ -12846,6 +12901,42 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetProvider() {
|
||||
m.provider = nil
|
||||
}
|
||||
|
||||
// SetAPIMode sets the "api_mode" field.
|
||||
func (m *ChannelMonitorRequestTemplateMutation) SetAPIMode(s string) {
|
||||
m.api_mode = &s
|
||||
}
|
||||
|
||||
// APIMode returns the value of the "api_mode" field in the mutation.
|
||||
func (m *ChannelMonitorRequestTemplateMutation) APIMode() (r string, exists bool) {
|
||||
v := m.api_mode
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldAPIMode returns the old "api_mode" field's value of the ChannelMonitorRequestTemplate entity.
|
||||
// If the ChannelMonitorRequestTemplate object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *ChannelMonitorRequestTemplateMutation) OldAPIMode(ctx context.Context) (v string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldAPIMode is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldAPIMode requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldAPIMode: %w", err)
|
||||
}
|
||||
return oldValue.APIMode, nil
|
||||
}
|
||||
|
||||
// ResetAPIMode resets all changes to the "api_mode" field.
|
||||
func (m *ChannelMonitorRequestTemplateMutation) ResetAPIMode() {
|
||||
m.api_mode = nil
|
||||
}
|
||||
|
||||
// SetDescription sets the "description" field.
|
||||
func (m *ChannelMonitorRequestTemplateMutation) SetDescription(s string) {
|
||||
m.description = &s
|
||||
@ -13104,7 +13195,7 @@ func (m *ChannelMonitorRequestTemplateMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
|
||||
fields := make([]string, 0, 8)
|
||||
fields := make([]string, 0, 9)
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, channelmonitorrequesttemplate.FieldCreatedAt)
|
||||
}
|
||||
@ -13117,6 +13208,9 @@ func (m *ChannelMonitorRequestTemplateMutation) Fields() []string {
|
||||
if m.provider != nil {
|
||||
fields = append(fields, channelmonitorrequesttemplate.FieldProvider)
|
||||
}
|
||||
if m.api_mode != nil {
|
||||
fields = append(fields, channelmonitorrequesttemplate.FieldAPIMode)
|
||||
}
|
||||
if m.description != nil {
|
||||
fields = append(fields, channelmonitorrequesttemplate.FieldDescription)
|
||||
}
|
||||
@ -13145,6 +13239,8 @@ func (m *ChannelMonitorRequestTemplateMutation) Field(name string) (ent.Value, b
|
||||
return m.Name()
|
||||
case channelmonitorrequesttemplate.FieldProvider:
|
||||
return m.Provider()
|
||||
case channelmonitorrequesttemplate.FieldAPIMode:
|
||||
return m.APIMode()
|
||||
case channelmonitorrequesttemplate.FieldDescription:
|
||||
return m.Description()
|
||||
case channelmonitorrequesttemplate.FieldExtraHeaders:
|
||||
@ -13170,6 +13266,8 @@ func (m *ChannelMonitorRequestTemplateMutation) OldField(ctx context.Context, na
|
||||
return m.OldName(ctx)
|
||||
case channelmonitorrequesttemplate.FieldProvider:
|
||||
return m.OldProvider(ctx)
|
||||
case channelmonitorrequesttemplate.FieldAPIMode:
|
||||
return m.OldAPIMode(ctx)
|
||||
case channelmonitorrequesttemplate.FieldDescription:
|
||||
return m.OldDescription(ctx)
|
||||
case channelmonitorrequesttemplate.FieldExtraHeaders:
|
||||
@ -13215,6 +13313,13 @@ func (m *ChannelMonitorRequestTemplateMutation) SetField(name string, value ent.
|
||||
}
|
||||
m.SetProvider(v)
|
||||
return nil
|
||||
case channelmonitorrequesttemplate.FieldAPIMode:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetAPIMode(v)
|
||||
return nil
|
||||
case channelmonitorrequesttemplate.FieldDescription:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
@ -13319,6 +13424,9 @@ func (m *ChannelMonitorRequestTemplateMutation) ResetField(name string) error {
|
||||
case channelmonitorrequesttemplate.FieldProvider:
|
||||
m.ResetProvider()
|
||||
return nil
|
||||
case channelmonitorrequesttemplate.FieldAPIMode:
|
||||
m.ResetAPIMode()
|
||||
return nil
|
||||
case channelmonitorrequesttemplate.FieldDescription:
|
||||
m.ResetDescription()
|
||||
return nil
|
||||
@ -28602,6 +28710,7 @@ type RedeemCodeMutation struct {
|
||||
used_at *time.Time
|
||||
notes *string
|
||||
created_at *time.Time
|
||||
expires_at *time.Time
|
||||
validity_days *int
|
||||
addvalidity_days *int
|
||||
clearedFields map[string]struct{}
|
||||
@ -29059,6 +29168,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() {
|
||||
m.created_at = nil
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) {
|
||||
m.expires_at = &t
|
||||
}
|
||||
|
||||
// ExpiresAt returns the value of the "expires_at" field in the mutation.
|
||||
func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) {
|
||||
v := m.expires_at
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity.
|
||||
// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldExpiresAt requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err)
|
||||
}
|
||||
return oldValue.ExpiresAt, nil
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) ClearExpiresAt() {
|
||||
m.expires_at = nil
|
||||
m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{}
|
||||
}
|
||||
|
||||
// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation.
|
||||
func (m *RedeemCodeMutation) ExpiresAtCleared() bool {
|
||||
_, ok := m.clearedFields[redeemcode.FieldExpiresAt]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetExpiresAt resets all changes to the "expires_at" field.
|
||||
func (m *RedeemCodeMutation) ResetExpiresAt() {
|
||||
m.expires_at = nil
|
||||
delete(m.clearedFields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (m *RedeemCodeMutation) SetGroupID(i int64) {
|
||||
m.group = &i
|
||||
@ -29265,7 +29423,7 @@ func (m *RedeemCodeMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *RedeemCodeMutation) Fields() []string {
|
||||
fields := make([]string, 0, 10)
|
||||
fields := make([]string, 0, 11)
|
||||
if m.code != nil {
|
||||
fields = append(fields, redeemcode.FieldCode)
|
||||
}
|
||||
@ -29290,6 +29448,9 @@ func (m *RedeemCodeMutation) Fields() []string {
|
||||
if m.created_at != nil {
|
||||
fields = append(fields, redeemcode.FieldCreatedAt)
|
||||
}
|
||||
if m.expires_at != nil {
|
||||
fields = append(fields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
if m.group != nil {
|
||||
fields = append(fields, redeemcode.FieldGroupID)
|
||||
}
|
||||
@ -29320,6 +29481,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.Notes()
|
||||
case redeemcode.FieldCreatedAt:
|
||||
return m.CreatedAt()
|
||||
case redeemcode.FieldExpiresAt:
|
||||
return m.ExpiresAt()
|
||||
case redeemcode.FieldGroupID:
|
||||
return m.GroupID()
|
||||
case redeemcode.FieldValidityDays:
|
||||
@ -29349,6 +29512,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val
|
||||
return m.OldNotes(ctx)
|
||||
case redeemcode.FieldCreatedAt:
|
||||
return m.OldCreatedAt(ctx)
|
||||
case redeemcode.FieldExpiresAt:
|
||||
return m.OldExpiresAt(ctx)
|
||||
case redeemcode.FieldGroupID:
|
||||
return m.OldGroupID(ctx)
|
||||
case redeemcode.FieldValidityDays:
|
||||
@ -29418,6 +29583,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetCreatedAt(v)
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
v, ok := value.(time.Time)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetExpiresAt(v)
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
v, ok := value.(int64)
|
||||
if !ok {
|
||||
@ -29498,6 +29670,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(redeemcode.FieldNotes) {
|
||||
fields = append(fields, redeemcode.FieldNotes)
|
||||
}
|
||||
if m.FieldCleared(redeemcode.FieldExpiresAt) {
|
||||
fields = append(fields, redeemcode.FieldExpiresAt)
|
||||
}
|
||||
if m.FieldCleared(redeemcode.FieldGroupID) {
|
||||
fields = append(fields, redeemcode.FieldGroupID)
|
||||
}
|
||||
@ -29524,6 +29699,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error {
|
||||
case redeemcode.FieldNotes:
|
||||
m.ClearNotes()
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
m.ClearExpiresAt()
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
m.ClearGroupID()
|
||||
return nil
|
||||
@ -29559,6 +29737,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error {
|
||||
case redeemcode.FieldCreatedAt:
|
||||
m.ResetCreatedAt()
|
||||
return nil
|
||||
case redeemcode.FieldExpiresAt:
|
||||
m.ResetExpiresAt()
|
||||
return nil
|
||||
case redeemcode.FieldGroupID:
|
||||
m.ResetGroupID()
|
||||
return nil
|
||||
@ -34260,6 +34441,10 @@ type UsageLogMutation struct {
|
||||
image_count *int
|
||||
addimage_count *int
|
||||
image_size *string
|
||||
image_input_size *string
|
||||
image_output_size *string
|
||||
image_size_source *string
|
||||
image_size_breakdown *map[string]int
|
||||
cache_ttl_overridden *bool
|
||||
created_at *time.Time
|
||||
clearedFields map[string]struct{}
|
||||
@ -36202,6 +36387,202 @@ func (m *UsageLogMutation) ResetImageSize() {
|
||||
delete(m.clearedFields, usagelog.FieldImageSize)
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (m *UsageLogMutation) SetImageInputSize(s string) {
|
||||
m.image_input_size = &s
|
||||
}
|
||||
|
||||
// ImageInputSize returns the value of the "image_input_size" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) {
|
||||
v := m.image_input_size
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageInputSize requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err)
|
||||
}
|
||||
return oldValue.ImageInputSize, nil
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (m *UsageLogMutation) ClearImageInputSize() {
|
||||
m.image_input_size = nil
|
||||
m.clearedFields[usagelog.FieldImageInputSize] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageInputSizeCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageInputSize]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageInputSize resets all changes to the "image_input_size" field.
|
||||
func (m *UsageLogMutation) ResetImageInputSize() {
|
||||
m.image_input_size = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (m *UsageLogMutation) SetImageOutputSize(s string) {
|
||||
m.image_output_size = &s
|
||||
}
|
||||
|
||||
// ImageOutputSize returns the value of the "image_output_size" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) {
|
||||
v := m.image_output_size
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageOutputSize requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err)
|
||||
}
|
||||
return oldValue.ImageOutputSize, nil
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (m *UsageLogMutation) ClearImageOutputSize() {
|
||||
m.image_output_size = nil
|
||||
m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageOutputSizeCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageOutputSize]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageOutputSize resets all changes to the "image_output_size" field.
|
||||
func (m *UsageLogMutation) ResetImageOutputSize() {
|
||||
m.image_output_size = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (m *UsageLogMutation) SetImageSizeSource(s string) {
|
||||
m.image_size_source = &s
|
||||
}
|
||||
|
||||
// ImageSizeSource returns the value of the "image_size_source" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) {
|
||||
v := m.image_size_source
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageSizeSource requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err)
|
||||
}
|
||||
return oldValue.ImageSizeSource, nil
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (m *UsageLogMutation) ClearImageSizeSource() {
|
||||
m.image_size_source = nil
|
||||
m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageSizeSourceCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageSizeSource]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageSizeSource resets all changes to the "image_size_source" field.
|
||||
func (m *UsageLogMutation) ResetImageSizeSource() {
|
||||
m.image_size_source = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) {
|
||||
m.image_size_breakdown = &value
|
||||
}
|
||||
|
||||
// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation.
|
||||
func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) {
|
||||
v := m.image_size_breakdown
|
||||
if v == nil {
|
||||
return
|
||||
}
|
||||
return *v, true
|
||||
}
|
||||
|
||||
// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity.
|
||||
// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
|
||||
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
|
||||
func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) {
|
||||
if !m.op.Is(OpUpdateOne) {
|
||||
return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations")
|
||||
}
|
||||
if m.id == nil || m.oldValue == nil {
|
||||
return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation")
|
||||
}
|
||||
oldValue, err := m.oldValue(ctx)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err)
|
||||
}
|
||||
return oldValue.ImageSizeBreakdown, nil
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) ClearImageSizeBreakdown() {
|
||||
m.image_size_breakdown = nil
|
||||
m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{}
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation.
|
||||
func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool {
|
||||
_, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field.
|
||||
func (m *UsageLogMutation) ResetImageSizeBreakdown() {
|
||||
m.image_size_breakdown = nil
|
||||
delete(m.clearedFields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) {
|
||||
m.cache_ttl_overridden = &b
|
||||
@ -36443,7 +36824,7 @@ func (m *UsageLogMutation) Type() string {
|
||||
// order to get all numeric fields that were incremented/decremented, call
|
||||
// AddedFields().
|
||||
func (m *UsageLogMutation) Fields() []string {
|
||||
fields := make([]string, 0, 37)
|
||||
fields := make([]string, 0, 41)
|
||||
if m.user != nil {
|
||||
fields = append(fields, usagelog.FieldUserID)
|
||||
}
|
||||
@ -36549,6 +36930,18 @@ func (m *UsageLogMutation) Fields() []string {
|
||||
if m.image_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageSize)
|
||||
}
|
||||
if m.image_input_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
if m.image_output_size != nil {
|
||||
fields = append(fields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
if m.image_size_source != nil {
|
||||
fields = append(fields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
if m.image_size_breakdown != nil {
|
||||
fields = append(fields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
if m.cache_ttl_overridden != nil {
|
||||
fields = append(fields, usagelog.FieldCacheTTLOverridden)
|
||||
}
|
||||
@ -36633,6 +37026,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
|
||||
return m.ImageCount()
|
||||
case usagelog.FieldImageSize:
|
||||
return m.ImageSize()
|
||||
case usagelog.FieldImageInputSize:
|
||||
return m.ImageInputSize()
|
||||
case usagelog.FieldImageOutputSize:
|
||||
return m.ImageOutputSize()
|
||||
case usagelog.FieldImageSizeSource:
|
||||
return m.ImageSizeSource()
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
return m.ImageSizeBreakdown()
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.CacheTTLOverridden()
|
||||
case usagelog.FieldCreatedAt:
|
||||
@ -36716,6 +37117,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
|
||||
return m.OldImageCount(ctx)
|
||||
case usagelog.FieldImageSize:
|
||||
return m.OldImageSize(ctx)
|
||||
case usagelog.FieldImageInputSize:
|
||||
return m.OldImageInputSize(ctx)
|
||||
case usagelog.FieldImageOutputSize:
|
||||
return m.OldImageOutputSize(ctx)
|
||||
case usagelog.FieldImageSizeSource:
|
||||
return m.OldImageSizeSource(ctx)
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
return m.OldImageSizeBreakdown(ctx)
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
return m.OldCacheTTLOverridden(ctx)
|
||||
case usagelog.FieldCreatedAt:
|
||||
@ -36974,6 +37383,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
|
||||
}
|
||||
m.SetImageSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageInputSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageOutputSize(v)
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
v, ok := value.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageSizeSource(v)
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
v, ok := value.(map[string]int)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected type %T for field %s", value, name)
|
||||
}
|
||||
m.SetImageSizeBreakdown(v)
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
v, ok := value.(bool)
|
||||
if !ok {
|
||||
@ -37291,6 +37728,18 @@ func (m *UsageLogMutation) ClearedFields() []string {
|
||||
if m.FieldCleared(usagelog.FieldImageSize) {
|
||||
fields = append(fields, usagelog.FieldImageSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageInputSize) {
|
||||
fields = append(fields, usagelog.FieldImageInputSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageOutputSize) {
|
||||
fields = append(fields, usagelog.FieldImageOutputSize)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageSizeSource) {
|
||||
fields = append(fields, usagelog.FieldImageSizeSource)
|
||||
}
|
||||
if m.FieldCleared(usagelog.FieldImageSizeBreakdown) {
|
||||
fields = append(fields, usagelog.FieldImageSizeBreakdown)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
@ -37347,6 +37796,18 @@ func (m *UsageLogMutation) ClearField(name string) error {
|
||||
case usagelog.FieldImageSize:
|
||||
m.ClearImageSize()
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
m.ClearImageInputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
m.ClearImageOutputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
m.ClearImageSizeSource()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
m.ClearImageSizeBreakdown()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("unknown UsageLog nullable field %s", name)
|
||||
}
|
||||
@ -37460,6 +37921,18 @@ func (m *UsageLogMutation) ResetField(name string) error {
|
||||
case usagelog.FieldImageSize:
|
||||
m.ResetImageSize()
|
||||
return nil
|
||||
case usagelog.FieldImageInputSize:
|
||||
m.ResetImageInputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageOutputSize:
|
||||
m.ResetImageOutputSize()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeSource:
|
||||
m.ResetImageSizeSource()
|
||||
return nil
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
m.ResetImageSizeBreakdown()
|
||||
return nil
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
m.ResetCacheTTLOverridden()
|
||||
return nil
|
||||
|
||||
@ -35,6 +35,8 @@ type RedeemCode struct {
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
// ExpiresAt holds the value of the "expires_at" field.
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
// GroupID holds the value of the "group_id" field.
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
// ValidityDays holds the value of the "validity_days" field.
|
||||
@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) {
|
||||
values[i] = new(sql.NullInt64)
|
||||
case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes:
|
||||
values[i] = new(sql.NullString)
|
||||
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt:
|
||||
case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
default:
|
||||
values[i] = new(sql.UnknownType)
|
||||
@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error {
|
||||
} else if value.Valid {
|
||||
_m.CreatedAt = value.Time
|
||||
}
|
||||
case redeemcode.FieldExpiresAt:
|
||||
if value, ok := values[i].(*sql.NullTime); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field expires_at", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ExpiresAt = new(time.Time)
|
||||
*_m.ExpiresAt = value.Time
|
||||
}
|
||||
case redeemcode.FieldGroupID:
|
||||
if value, ok := values[i].(*sql.NullInt64); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field group_id", values[i])
|
||||
@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string {
|
||||
builder.WriteString("created_at=")
|
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ExpiresAt; v != nil {
|
||||
builder.WriteString("expires_at=")
|
||||
builder.WriteString(v.Format(time.ANSIC))
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.GroupID; v != nil {
|
||||
builder.WriteString("group_id=")
|
||||
builder.WriteString(fmt.Sprintf("%v", *v))
|
||||
|
||||
@ -30,6 +30,8 @@ const (
|
||||
FieldNotes = "notes"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at"
|
||||
// FieldExpiresAt holds the string denoting the expires_at field in the database.
|
||||
FieldExpiresAt = "expires_at"
|
||||
// FieldGroupID holds the string denoting the group_id field in the database.
|
||||
FieldGroupID = "group_id"
|
||||
// FieldValidityDays holds the string denoting the validity_days field in the database.
|
||||
@ -67,6 +69,7 @@ var Columns = []string{
|
||||
FieldUsedAt,
|
||||
FieldNotes,
|
||||
FieldCreatedAt,
|
||||
FieldExpiresAt,
|
||||
FieldGroupID,
|
||||
FieldValidityDays,
|
||||
}
|
||||
@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByExpiresAt orders the results by the expires_at field.
|
||||
func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldExpiresAt, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByGroupID orders the results by the group_id field.
|
||||
func ByGroupID(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldGroupID, opts...).ToFunc()
|
||||
|
||||
@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ.
|
||||
func ExpiresAt(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ.
|
||||
func GroupID(v int64) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
|
||||
@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtEQ applies the EQ predicate on the "expires_at" field.
|
||||
func ExpiresAtEQ(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field.
|
||||
func ExpiresAtNEQ(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIn applies the In predicate on the "expires_at" field.
|
||||
func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field.
|
||||
func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...))
|
||||
}
|
||||
|
||||
// ExpiresAtGT applies the GT predicate on the "expires_at" field.
|
||||
func ExpiresAtGT(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtGTE applies the GTE predicate on the "expires_at" field.
|
||||
func ExpiresAtGTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLT applies the LT predicate on the "expires_at" field.
|
||||
func ExpiresAtLT(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtLTE applies the LTE predicate on the "expires_at" field.
|
||||
func ExpiresAtLTE(v time.Time) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v))
|
||||
}
|
||||
|
||||
// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field.
|
||||
func ExpiresAtIsNil() predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field.
|
||||
func ExpiresAtNotNil() predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt))
|
||||
}
|
||||
|
||||
// GroupIDEQ applies the EQ predicate on the "group_id" field.
|
||||
func GroupIDEQ(v int64) predicate.RedeemCode {
|
||||
return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v))
|
||||
|
||||
@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate {
|
||||
_c.mutation.SetExpiresAt(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate {
|
||||
if v != nil {
|
||||
_c.SetExpiresAt(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate {
|
||||
_c.mutation.SetGroupID(v)
|
||||
@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value)
|
||||
_node.CreatedAt = value
|
||||
}
|
||||
if value, ok := _c.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
_node.ExpiresAt = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
_node.ValidityDays = value
|
||||
@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert {
|
||||
u.Set(redeemcode.FieldExpiresAt, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert {
|
||||
u.SetExcluded(redeemcode.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert {
|
||||
u.SetNull(redeemcode.FieldExpiresAt)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert {
|
||||
u.Set(redeemcode.FieldGroupID, v)
|
||||
@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.SetExpiresAt(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
|
||||
func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.UpdateExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
s.ClearExpiresAt()
|
||||
})
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk {
|
||||
return u.Update(func(s *RedeemCodeUpsert) {
|
||||
|
||||
@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error)
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
}
|
||||
@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetExpiresAt sets the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne {
|
||||
_u.mutation.SetExpiresAt(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil.
|
||||
func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetExpiresAt(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearExpiresAt clears the value of the "expires_at" field.
|
||||
func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne {
|
||||
_u.mutation.ClearExpiresAt()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetGroupID sets the "group_id" field.
|
||||
func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne {
|
||||
_u.mutation.SetGroupID(v)
|
||||
@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode,
|
||||
if _u.mutation.NotesCleared() {
|
||||
_spec.ClearField(redeemcode.FieldNotes, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ExpiresAt(); ok {
|
||||
_spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value)
|
||||
}
|
||||
if _u.mutation.ExpiresAtCleared() {
|
||||
_spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime)
|
||||
}
|
||||
if value, ok := _u.mutation.ValidityDays(); ok {
|
||||
_spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value)
|
||||
}
|
||||
|
||||
@ -464,8 +464,14 @@ func init() {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// channelmonitorDescAPIMode is the schema descriptor for api_mode field.
|
||||
channelmonitorDescAPIMode := channelmonitorFields[2].Descriptor()
|
||||
// channelmonitor.DefaultAPIMode holds the default value on creation for the api_mode field.
|
||||
channelmonitor.DefaultAPIMode = channelmonitorDescAPIMode.Default.(string)
|
||||
// channelmonitor.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
|
||||
channelmonitor.APIModeValidator = channelmonitorDescAPIMode.Validators[0].(func(string) error)
|
||||
// channelmonitorDescEndpoint is the schema descriptor for endpoint field.
|
||||
channelmonitorDescEndpoint := channelmonitorFields[2].Descriptor()
|
||||
channelmonitorDescEndpoint := channelmonitorFields[3].Descriptor()
|
||||
// channelmonitor.EndpointValidator is a validator for the "endpoint" field. It is called by the builders before save.
|
||||
channelmonitor.EndpointValidator = func() func(string) error {
|
||||
validators := channelmonitorDescEndpoint.Validators
|
||||
@ -483,11 +489,11 @@ func init() {
|
||||
}
|
||||
}()
|
||||
// channelmonitorDescAPIKeyEncrypted is the schema descriptor for api_key_encrypted field.
|
||||
channelmonitorDescAPIKeyEncrypted := channelmonitorFields[3].Descriptor()
|
||||
channelmonitorDescAPIKeyEncrypted := channelmonitorFields[4].Descriptor()
|
||||
// channelmonitor.APIKeyEncryptedValidator is a validator for the "api_key_encrypted" field. It is called by the builders before save.
|
||||
channelmonitor.APIKeyEncryptedValidator = channelmonitorDescAPIKeyEncrypted.Validators[0].(func(string) error)
|
||||
// channelmonitorDescPrimaryModel is the schema descriptor for primary_model field.
|
||||
channelmonitorDescPrimaryModel := channelmonitorFields[4].Descriptor()
|
||||
channelmonitorDescPrimaryModel := channelmonitorFields[5].Descriptor()
|
||||
// channelmonitor.PrimaryModelValidator is a validator for the "primary_model" field. It is called by the builders before save.
|
||||
channelmonitor.PrimaryModelValidator = func() func(string) error {
|
||||
validators := channelmonitorDescPrimaryModel.Validators
|
||||
@ -505,29 +511,29 @@ func init() {
|
||||
}
|
||||
}()
|
||||
// channelmonitorDescExtraModels is the schema descriptor for extra_models field.
|
||||
channelmonitorDescExtraModels := channelmonitorFields[5].Descriptor()
|
||||
channelmonitorDescExtraModels := channelmonitorFields[6].Descriptor()
|
||||
// channelmonitor.DefaultExtraModels holds the default value on creation for the extra_models field.
|
||||
channelmonitor.DefaultExtraModels = channelmonitorDescExtraModels.Default.([]string)
|
||||
// channelmonitorDescGroupName is the schema descriptor for group_name field.
|
||||
channelmonitorDescGroupName := channelmonitorFields[6].Descriptor()
|
||||
channelmonitorDescGroupName := channelmonitorFields[7].Descriptor()
|
||||
// channelmonitor.DefaultGroupName holds the default value on creation for the group_name field.
|
||||
channelmonitor.DefaultGroupName = channelmonitorDescGroupName.Default.(string)
|
||||
// channelmonitor.GroupNameValidator is a validator for the "group_name" field. It is called by the builders before save.
|
||||
channelmonitor.GroupNameValidator = channelmonitorDescGroupName.Validators[0].(func(string) error)
|
||||
// channelmonitorDescEnabled is the schema descriptor for enabled field.
|
||||
channelmonitorDescEnabled := channelmonitorFields[7].Descriptor()
|
||||
channelmonitorDescEnabled := channelmonitorFields[8].Descriptor()
|
||||
// channelmonitor.DefaultEnabled holds the default value on creation for the enabled field.
|
||||
channelmonitor.DefaultEnabled = channelmonitorDescEnabled.Default.(bool)
|
||||
// channelmonitorDescIntervalSeconds is the schema descriptor for interval_seconds field.
|
||||
channelmonitorDescIntervalSeconds := channelmonitorFields[8].Descriptor()
|
||||
channelmonitorDescIntervalSeconds := channelmonitorFields[9].Descriptor()
|
||||
// channelmonitor.IntervalSecondsValidator is a validator for the "interval_seconds" field. It is called by the builders before save.
|
||||
channelmonitor.IntervalSecondsValidator = channelmonitorDescIntervalSeconds.Validators[0].(func(int) error)
|
||||
// channelmonitorDescExtraHeaders is the schema descriptor for extra_headers field.
|
||||
channelmonitorDescExtraHeaders := channelmonitorFields[12].Descriptor()
|
||||
channelmonitorDescExtraHeaders := channelmonitorFields[13].Descriptor()
|
||||
// channelmonitor.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
|
||||
channelmonitor.DefaultExtraHeaders = channelmonitorDescExtraHeaders.Default.(map[string]string)
|
||||
// channelmonitorDescBodyOverrideMode is the schema descriptor for body_override_mode field.
|
||||
channelmonitorDescBodyOverrideMode := channelmonitorFields[13].Descriptor()
|
||||
channelmonitorDescBodyOverrideMode := channelmonitorFields[14].Descriptor()
|
||||
// channelmonitor.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
|
||||
channelmonitor.DefaultBodyOverrideMode = channelmonitorDescBodyOverrideMode.Default.(string)
|
||||
// channelmonitor.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
|
||||
@ -661,18 +667,24 @@ func init() {
|
||||
return nil
|
||||
}
|
||||
}()
|
||||
// channelmonitorrequesttemplateDescAPIMode is the schema descriptor for api_mode field.
|
||||
channelmonitorrequesttemplateDescAPIMode := channelmonitorrequesttemplateFields[2].Descriptor()
|
||||
// channelmonitorrequesttemplate.DefaultAPIMode holds the default value on creation for the api_mode field.
|
||||
channelmonitorrequesttemplate.DefaultAPIMode = channelmonitorrequesttemplateDescAPIMode.Default.(string)
|
||||
// channelmonitorrequesttemplate.APIModeValidator is a validator for the "api_mode" field. It is called by the builders before save.
|
||||
channelmonitorrequesttemplate.APIModeValidator = channelmonitorrequesttemplateDescAPIMode.Validators[0].(func(string) error)
|
||||
// channelmonitorrequesttemplateDescDescription is the schema descriptor for description field.
|
||||
channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[2].Descriptor()
|
||||
channelmonitorrequesttemplateDescDescription := channelmonitorrequesttemplateFields[3].Descriptor()
|
||||
// channelmonitorrequesttemplate.DefaultDescription holds the default value on creation for the description field.
|
||||
channelmonitorrequesttemplate.DefaultDescription = channelmonitorrequesttemplateDescDescription.Default.(string)
|
||||
// channelmonitorrequesttemplate.DescriptionValidator is a validator for the "description" field. It is called by the builders before save.
|
||||
channelmonitorrequesttemplate.DescriptionValidator = channelmonitorrequesttemplateDescDescription.Validators[0].(func(string) error)
|
||||
// channelmonitorrequesttemplateDescExtraHeaders is the schema descriptor for extra_headers field.
|
||||
channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[3].Descriptor()
|
||||
channelmonitorrequesttemplateDescExtraHeaders := channelmonitorrequesttemplateFields[4].Descriptor()
|
||||
// channelmonitorrequesttemplate.DefaultExtraHeaders holds the default value on creation for the extra_headers field.
|
||||
channelmonitorrequesttemplate.DefaultExtraHeaders = channelmonitorrequesttemplateDescExtraHeaders.Default.(map[string]string)
|
||||
// channelmonitorrequesttemplateDescBodyOverrideMode is the schema descriptor for body_override_mode field.
|
||||
channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[4].Descriptor()
|
||||
channelmonitorrequesttemplateDescBodyOverrideMode := channelmonitorrequesttemplateFields[5].Descriptor()
|
||||
// channelmonitorrequesttemplate.DefaultBodyOverrideMode holds the default value on creation for the body_override_mode field.
|
||||
channelmonitorrequesttemplate.DefaultBodyOverrideMode = channelmonitorrequesttemplateDescBodyOverrideMode.Default.(string)
|
||||
// channelmonitorrequesttemplate.BodyOverrideModeValidator is a validator for the "body_override_mode" field. It is called by the builders before save.
|
||||
@ -1386,7 +1398,7 @@ func init() {
|
||||
// redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time)
|
||||
// redeemcodeDescValidityDays is the schema descriptor for validity_days field.
|
||||
redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor()
|
||||
redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor()
|
||||
// redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field.
|
||||
redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int)
|
||||
securitysecretMixin := schema.SecuritySecret{}.Mixin()
|
||||
@ -1722,12 +1734,24 @@ func init() {
|
||||
usagelogDescImageSize := usagelogFields[34].Descriptor()
|
||||
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageInputSize is the schema descriptor for image_input_size field.
|
||||
usagelogDescImageInputSize := usagelogFields[35].Descriptor()
|
||||
// usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
|
||||
usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageOutputSize is the schema descriptor for image_output_size field.
|
||||
usagelogDescImageOutputSize := usagelogFields[36].Descriptor()
|
||||
// usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
|
||||
usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error)
|
||||
// usagelogDescImageSizeSource is the schema descriptor for image_size_source field.
|
||||
usagelogDescImageSizeSource := usagelogFields[37].Descriptor()
|
||||
// usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
|
||||
usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error)
|
||||
// usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field.
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor()
|
||||
usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor()
|
||||
// usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field.
|
||||
usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool)
|
||||
// usagelogDescCreatedAt is the schema descriptor for created_at field.
|
||||
usagelogDescCreatedAt := usagelogFields[36].Descriptor()
|
||||
usagelogDescCreatedAt := usagelogFields[40].Descriptor()
|
||||
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
|
||||
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
|
||||
userMixin := schema.User{}.Mixin()
|
||||
|
||||
@ -15,12 +15,13 @@ import (
|
||||
)
|
||||
|
||||
var authProviderTypes = map[string]struct{}{
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
"email": {},
|
||||
"github": {},
|
||||
"google": {},
|
||||
"linuxdo": {},
|
||||
"oidc": {},
|
||||
"wechat": {},
|
||||
"dingtalk": {},
|
||||
}
|
||||
|
||||
func validateAuthProviderType(value string) error {
|
||||
|
||||
@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) {
|
||||
require.Equal(t, 1, signupSource.Validators)
|
||||
|
||||
validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source")
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} {
|
||||
for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} {
|
||||
require.NoError(t, validator(value))
|
||||
}
|
||||
require.Error(t, validator("unknown"))
|
||||
|
||||
@ -36,6 +36,10 @@ func (ChannelMonitor) Fields() []ent.Field {
|
||||
MaxLen(100),
|
||||
field.Enum("provider").
|
||||
Values("openai", "anthropic", "gemini"),
|
||||
field.String("api_mode").
|
||||
Default("chat_completions").
|
||||
MaxLen(32).
|
||||
Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
|
||||
field.String("endpoint").
|
||||
NotEmpty().
|
||||
MaxLen(500).
|
||||
@ -104,6 +108,7 @@ func (ChannelMonitor) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
index.Fields("enabled", "last_checked_at"),
|
||||
index.Fields("provider"),
|
||||
index.Fields("provider", "api_mode"),
|
||||
index.Fields("group_name"),
|
||||
index.Fields("template_id"),
|
||||
}
|
||||
|
||||
@ -40,6 +40,10 @@ func (ChannelMonitorRequestTemplate) Fields() []ent.Field {
|
||||
MaxLen(100),
|
||||
field.Enum("provider").
|
||||
Values("openai", "anthropic", "gemini"),
|
||||
field.String("api_mode").
|
||||
Default("chat_completions").
|
||||
MaxLen(32).
|
||||
Comment("OpenAI request protocol: chat_completions or responses; non-OpenAI uses chat_completions"),
|
||||
field.String("description").
|
||||
Optional().
|
||||
Default("").
|
||||
@ -76,5 +80,6 @@ func (ChannelMonitorRequestTemplate) Indexes() []ent.Index {
|
||||
return []ent.Index{
|
||||
// 同一 provider 内 name 唯一:允许 Anthropic + OpenAI 重名 "伪装官方客户端"。
|
||||
index.Fields("provider", "name").Unique(),
|
||||
index.Fields("provider", "api_mode"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field {
|
||||
Immutable().
|
||||
Default(time.Now).
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Time("expires_at").
|
||||
Optional().
|
||||
Nillable().
|
||||
SchemaType(map[string]string{dialect.Postgres: "timestamptz"}),
|
||||
field.Int64("group_id").
|
||||
Optional().
|
||||
Nillable(),
|
||||
@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index {
|
||||
index.Fields("status"),
|
||||
index.Fields("used_by"),
|
||||
index.Fields("group_id"),
|
||||
index.Fields("expires_at"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field {
|
||||
MaxLen(10).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_input_size").
|
||||
MaxLen(32).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_output_size").
|
||||
MaxLen(32).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.String("image_size_source").
|
||||
MaxLen(16).
|
||||
Optional().
|
||||
Nillable(),
|
||||
field.JSON("image_size_breakdown", map[string]int{}).
|
||||
Optional().
|
||||
SchemaType(map[string]string{dialect.Postgres: "jsonb"}),
|
||||
// Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费)
|
||||
field.Bool("cache_ttl_overridden").
|
||||
Default(false),
|
||||
|
||||
@ -77,10 +77,10 @@ func (User) Fields() []ent.Field {
|
||||
field.String("signup_source").
|
||||
Validate(func(value string) error {
|
||||
switch value {
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google":
|
||||
case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google")
|
||||
return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk")
|
||||
}
|
||||
}).
|
||||
Default("email"),
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@ -92,6 +93,14 @@ type UsageLog struct {
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
// ImageSize holds the value of the "image_size" field.
|
||||
ImageSize *string `json:"image_size,omitempty"`
|
||||
// ImageInputSize holds the value of the "image_input_size" field.
|
||||
ImageInputSize *string `json:"image_input_size,omitempty"`
|
||||
// ImageOutputSize holds the value of the "image_output_size" field.
|
||||
ImageOutputSize *string `json:"image_output_size,omitempty"`
|
||||
// ImageSizeSource holds the value of the "image_size_source" field.
|
||||
ImageSizeSource *string `json:"image_size_source,omitempty"`
|
||||
// ImageSizeBreakdown holds the value of the "image_size_breakdown" field.
|
||||
ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"`
|
||||
// CacheTTLOverridden holds the value of the "cache_ttl_overridden" field.
|
||||
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"`
|
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
|
||||
values := make([]any, len(columns))
|
||||
for i := range columns {
|
||||
switch columns[i] {
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
values[i] = new([]byte)
|
||||
case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden:
|
||||
values[i] = new(sql.NullBool)
|
||||
case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
|
||||
values[i] = new(sql.NullFloat64)
|
||||
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
|
||||
values[i] = new(sql.NullInt64)
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize:
|
||||
case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource:
|
||||
values[i] = new(sql.NullString)
|
||||
case usagelog.FieldCreatedAt:
|
||||
values[i] = new(sql.NullTime)
|
||||
@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
|
||||
_m.ImageSize = new(string)
|
||||
*_m.ImageSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageInputSize:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_input_size", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageInputSize = new(string)
|
||||
*_m.ImageInputSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageOutputSize:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_output_size", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageOutputSize = new(string)
|
||||
*_m.ImageOutputSize = value.String
|
||||
}
|
||||
case usagelog.FieldImageSizeSource:
|
||||
if value, ok := values[i].(*sql.NullString); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_size_source", values[i])
|
||||
} else if value.Valid {
|
||||
_m.ImageSizeSource = new(string)
|
||||
*_m.ImageSizeSource = value.String
|
||||
}
|
||||
case usagelog.FieldImageSizeBreakdown:
|
||||
if value, ok := values[i].(*[]byte); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i])
|
||||
} else if value != nil && len(*value) > 0 {
|
||||
if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil {
|
||||
return fmt.Errorf("unmarshal field image_size_breakdown: %w", err)
|
||||
}
|
||||
}
|
||||
case usagelog.FieldCacheTTLOverridden:
|
||||
if value, ok := values[i].(*sql.NullBool); !ok {
|
||||
return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i])
|
||||
@ -640,6 +680,24 @@ func (_m *UsageLog) String() string {
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageInputSize; v != nil {
|
||||
builder.WriteString("image_input_size=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageOutputSize; v != nil {
|
||||
builder.WriteString("image_output_size=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
if v := _m.ImageSizeSource; v != nil {
|
||||
builder.WriteString("image_size_source=")
|
||||
builder.WriteString(*v)
|
||||
}
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("image_size_breakdown=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown))
|
||||
builder.WriteString(", ")
|
||||
builder.WriteString("cache_ttl_overridden=")
|
||||
builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden))
|
||||
builder.WriteString(", ")
|
||||
|
||||
@ -84,6 +84,14 @@ const (
|
||||
FieldImageCount = "image_count"
|
||||
// FieldImageSize holds the string denoting the image_size field in the database.
|
||||
FieldImageSize = "image_size"
|
||||
// FieldImageInputSize holds the string denoting the image_input_size field in the database.
|
||||
FieldImageInputSize = "image_input_size"
|
||||
// FieldImageOutputSize holds the string denoting the image_output_size field in the database.
|
||||
FieldImageOutputSize = "image_output_size"
|
||||
// FieldImageSizeSource holds the string denoting the image_size_source field in the database.
|
||||
FieldImageSizeSource = "image_size_source"
|
||||
// FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database.
|
||||
FieldImageSizeBreakdown = "image_size_breakdown"
|
||||
// FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database.
|
||||
FieldCacheTTLOverridden = "cache_ttl_overridden"
|
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
@ -175,6 +183,10 @@ var Columns = []string{
|
||||
FieldIPAddress,
|
||||
FieldImageCount,
|
||||
FieldImageSize,
|
||||
FieldImageInputSize,
|
||||
FieldImageOutputSize,
|
||||
FieldImageSizeSource,
|
||||
FieldImageSizeBreakdown,
|
||||
FieldCacheTTLOverridden,
|
||||
FieldCreatedAt,
|
||||
}
|
||||
@ -242,6 +254,12 @@ var (
|
||||
DefaultImageCount int
|
||||
// ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
|
||||
ImageSizeValidator func(string) error
|
||||
// ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save.
|
||||
ImageInputSizeValidator func(string) error
|
||||
// ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save.
|
||||
ImageOutputSizeValidator func(string) error
|
||||
// ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save.
|
||||
ImageSizeSourceValidator func(string) error
|
||||
// DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field.
|
||||
DefaultCacheTTLOverridden bool
|
||||
// DefaultCreatedAt holds the default value on creation for the "created_at" field.
|
||||
@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageInputSize orders the results by the image_input_size field.
|
||||
func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageInputSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageOutputSize orders the results by the image_output_size field.
|
||||
func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByImageSizeSource orders the results by the image_size_source field.
|
||||
func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc()
|
||||
}
|
||||
|
||||
// ByCacheTTLOverridden orders the results by the cache_ttl_overridden field.
|
||||
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption {
|
||||
return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc()
|
||||
|
||||
@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ.
|
||||
func ImageInputSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ.
|
||||
func ImageOutputSize(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ.
|
||||
func ImageSizeSource(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ.
|
||||
func CacheTTLOverridden(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field.
|
||||
func ImageInputSizeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeIn applies the In predicate on the "image_input_size" field.
|
||||
func ImageInputSizeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageInputSizeGT applies the GT predicate on the "image_input_size" field.
|
||||
func ImageInputSizeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field.
|
||||
func ImageInputSizeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeLT applies the LT predicate on the "image_input_size" field.
|
||||
func ImageInputSizeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field.
|
||||
func ImageInputSizeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field.
|
||||
func ImageInputSizeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field.
|
||||
func ImageInputSizeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field.
|
||||
func ImageInputSizeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field.
|
||||
func ImageInputSizeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize))
|
||||
}
|
||||
|
||||
// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field.
|
||||
func ImageInputSizeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize))
|
||||
}
|
||||
|
||||
// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field.
|
||||
func ImageInputSizeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field.
|
||||
func ImageInputSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeIn applies the In predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...))
|
||||
}
|
||||
|
||||
// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize))
|
||||
}
|
||||
|
||||
// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize))
|
||||
}
|
||||
|
||||
// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field.
|
||||
func ImageOutputSizeContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNEQ(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceIn applies the In predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...))
|
||||
}
|
||||
|
||||
// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceGT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceGTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceLT(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceLTE(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceContains(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceHasPrefix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceHasSuffix(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource))
|
||||
}
|
||||
|
||||
// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource))
|
||||
}
|
||||
|
||||
// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceEqualFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field.
|
||||
func ImageSizeSourceContainsFold(v string) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v))
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field.
|
||||
func ImageSizeBreakdownIsNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown))
|
||||
}
|
||||
|
||||
// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field.
|
||||
func ImageSizeBreakdownNotNil() predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown))
|
||||
}
|
||||
|
||||
// CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field.
|
||||
func CacheTTLOverriddenEQ(v bool) predicate.UsageLog {
|
||||
return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v))
|
||||
|
||||
@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageInputSize(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageInputSize(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageOutputSize(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageOutputSize(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate {
|
||||
_c.mutation.SetImageSizeSource(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate {
|
||||
if v != nil {
|
||||
_c.SetImageSizeSource(*v)
|
||||
}
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate {
|
||||
_c.mutation.SetImageSizeBreakdown(v)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate {
|
||||
_c.mutation.SetCacheTTLOverridden(v)
|
||||
@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _c.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _, ok := _c.mutation.CacheTTLOverridden(); !ok {
|
||||
return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)}
|
||||
}
|
||||
@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
|
||||
_spec.SetField(usagelog.FieldImageSize, field.TypeString, value)
|
||||
_node.ImageSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
_node.ImageInputSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
_node.ImageOutputSize = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
_node.ImageSizeSource = &value
|
||||
}
|
||||
if value, ok := _c.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
_node.ImageSizeBreakdown = value
|
||||
}
|
||||
if value, ok := _c.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
_node.CacheTTLOverridden = value
|
||||
@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert {
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageInputSize, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageInputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageInputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageOutputSize, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageOutputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageOutputSize)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageSizeSource, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageSizeSource)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageSizeSource)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldImageSizeBreakdown, v)
|
||||
return u
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert {
|
||||
u.SetExcluded(usagelog.FieldImageSizeBreakdown)
|
||||
return u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert {
|
||||
u.SetNull(usagelog.FieldImageSizeBreakdown)
|
||||
return u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert {
|
||||
u.Set(usagelog.FieldCacheTTLOverridden, v)
|
||||
@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageInputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageOutputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeSource(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeBreakdown(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk {
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageInputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageInputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageOutputSize(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageOutputSize()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeSource(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeSource()
|
||||
})
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.SetImageSizeBreakdown(v)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create.
|
||||
func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.UpdateImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
s.ClearImageSizeBreakdown()
|
||||
})
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk {
|
||||
return u.Update(func(s *UsageLogUpsert) {
|
||||
|
||||
@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageInputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageInputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageInputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageOutputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageOutputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageOutputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate {
|
||||
_u.mutation.SetImageSizeSource(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate {
|
||||
if v != nil {
|
||||
_u.SetImageSizeSource(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageSizeSource()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate {
|
||||
_u.mutation.SetImageSizeBreakdown(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate {
|
||||
_u.mutation.ClearImageSizeBreakdown()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageInputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageOutputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeSourceCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeBreakdownCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne {
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageInputSize sets the "image_input_size" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageInputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageInputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageInputSize clears the value of the "image_input_size" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageInputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageOutputSize sets the "image_output_size" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageOutputSize(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageOutputSize(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageOutputSize clears the value of the "image_output_size" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageOutputSize()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeSource sets the "image_size_source" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageSizeSource(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil.
|
||||
func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne {
|
||||
if v != nil {
|
||||
_u.SetImageSizeSource(*v)
|
||||
}
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeSource clears the value of the "image_size_source" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageSizeSource()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetImageSizeBreakdown sets the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne {
|
||||
_u.mutation.SetImageSizeBreakdown(v)
|
||||
return _u
|
||||
}
|
||||
|
||||
// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field.
|
||||
func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne {
|
||||
_u.mutation.ClearImageSizeBreakdown()
|
||||
return _u
|
||||
}
|
||||
|
||||
// SetCacheTTLOverridden sets the "cache_ttl_overridden" field.
|
||||
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne {
|
||||
_u.mutation.SetCacheTTLOverridden(v)
|
||||
@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error {
|
||||
return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageInputSize(); ok {
|
||||
if err := usagelog.ImageInputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
if err := usagelog.ImageOutputSizeValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)}
|
||||
}
|
||||
}
|
||||
if v, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
if err := usagelog.ImageSizeSourceValidator(v); err != nil {
|
||||
return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)}
|
||||
}
|
||||
}
|
||||
if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 {
|
||||
return errors.New(`ent: clearing a required unique edge "UsageLog.user"`)
|
||||
}
|
||||
@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
|
||||
if _u.mutation.ImageSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageInputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageInputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageInputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageOutputSize(); ok {
|
||||
_spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageOutputSizeCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeSource(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeSourceCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString)
|
||||
}
|
||||
if value, ok := _u.mutation.ImageSizeBreakdown(); ok {
|
||||
_spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value)
|
||||
}
|
||||
if _u.mutation.ImageSizeBreakdownCleared() {
|
||||
_spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON)
|
||||
}
|
||||
if value, ok := _u.mutation.CacheTTLOverridden(); ok {
|
||||
_spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value)
|
||||
}
|
||||
|
||||
@ -222,6 +222,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
|
||||
@ -255,6 +257,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||
@ -284,6 +288,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv
|
||||
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
|
||||
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
@ -316,6 +322,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
|
||||
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
|
||||
|
||||
@ -79,6 +79,7 @@ type Config struct {
|
||||
LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"`
|
||||
WeChat WeChatConnectConfig `mapstructure:"wechat_connect"`
|
||||
OIDC OIDCConnectConfig `mapstructure:"oidc_connect"`
|
||||
DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"`
|
||||
GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"`
|
||||
GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"`
|
||||
Default DefaultConfig `mapstructure:"default"`
|
||||
@ -250,6 +251,47 @@ type OIDCConnectConfig struct {
|
||||
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
|
||||
}
|
||||
|
||||
type DingTalkConnectConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
ClientSecret string `mapstructure:"client_secret"`
|
||||
AuthorizeURL string `mapstructure:"authorize_url"`
|
||||
TokenURL string `mapstructure:"token_url"`
|
||||
UserInfoURL string `mapstructure:"userinfo_url"`
|
||||
Scopes string `mapstructure:"scopes"`
|
||||
RedirectURL string `mapstructure:"redirect_url"`
|
||||
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
|
||||
|
||||
// 平台底座 + 业务行为
|
||||
DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"(V4 fail-closed)
|
||||
AppType string `mapstructure:"app_type"` // "public" (default) | "internal"
|
||||
|
||||
// Corp 限定(none | internal_only)
|
||||
CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"`
|
||||
InternalCorpID string `mapstructure:"internal_corp_id"`
|
||||
BypassRegistration bool `mapstructure:"bypass_registration"`
|
||||
SyncCorpEmail bool `mapstructure:"sync_corp_email"`
|
||||
SyncDisplayName bool `mapstructure:"sync_display_name"`
|
||||
SyncDept bool `mapstructure:"sync_dept"`
|
||||
SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"`
|
||||
SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"`
|
||||
SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"`
|
||||
SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"`
|
||||
SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"`
|
||||
SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"`
|
||||
|
||||
// 邮箱 + Username
|
||||
RequireEmail bool `mapstructure:"require_email"`
|
||||
UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"`
|
||||
|
||||
// Attribute(私有版扩展点;开源版仅声明)
|
||||
UsernameAttributeKey string `mapstructure:"username_attribute_key"`
|
||||
EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"`
|
||||
EnableAttributeSync bool `mapstructure:"enable_attribute_sync"`
|
||||
AttributeSyncFields []string `mapstructure:"attribute_sync_fields"`
|
||||
AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"`
|
||||
}
|
||||
|
||||
type EmailOAuthProviderConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ClientID string `mapstructure:"client_id"`
|
||||
@ -1639,6 +1681,19 @@ func setDefaults() {
|
||||
viper.SetDefault("oidc_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("oidc_connect.userinfo_username_path", "")
|
||||
|
||||
// DingTalk Connect OAuth 登录
|
||||
viper.SetDefault("dingtalk_connect.enabled", false)
|
||||
viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth")
|
||||
viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken")
|
||||
viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me")
|
||||
viper.SetDefault("dingtalk_connect.scopes", "openid")
|
||||
viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback")
|
||||
viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app")
|
||||
viper.SetDefault("dingtalk_connect.app_type", "public")
|
||||
viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none")
|
||||
viper.SetDefault("dingtalk_connect.require_email", true)
|
||||
viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@ -2767,6 +2822,9 @@ func (c *Config) Validate() error {
|
||||
if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 {
|
||||
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
|
||||
}
|
||||
if err := ValidateDingTalkConfig(c.DingTalk); err != nil {
|
||||
return fmt.Errorf("dingtalk_connect: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
30
backend/internal/config/validate_dingtalk.go
Normal file
30
backend/internal/config/validate_dingtalk.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Package config 包含钉钉连接配置的校验逻辑。
|
||||
//
|
||||
// internal_only 模式安全模型(方案 A):
|
||||
// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。
|
||||
// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth,
|
||||
// 因此 ValidateDingTalkConfig 只要求 app_type=internal(V1),不再要求 InternalCorpID 非空(原 V3 已删除)。
|
||||
// InternalCorpID 字段保留,admin 可选填;若填写,checkDingTalkCorpAllowed 不会使用它做约束。
|
||||
package config
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal")
|
||||
ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app")
|
||||
)
|
||||
|
||||
func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
if cfg.DingTalkAppKind != "internal_app" {
|
||||
return ErrDingTalkV4InvalidAppKind
|
||||
}
|
||||
if cfg.CorpRestrictionPolicy == "internal_only" {
|
||||
if cfg.AppType != "internal" {
|
||||
return ErrDingTalkV1AppTypeMismatch
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
53
backend/internal/config/validate_dingtalk_test.go
Normal file
53
backend/internal/config/validate_dingtalk_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) {
|
||||
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false}))
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "third_party_enterprise_app",
|
||||
CorpRestrictionPolicy: "none",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind)
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "public",
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "dingABC",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch)
|
||||
}
|
||||
|
||||
// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
|
||||
// internal_only 策略下,InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。
|
||||
func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
|
||||
err := ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "internal",
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) {
|
||||
require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app",
|
||||
AppType: "public",
|
||||
CorpRestrictionPolicy: "none",
|
||||
}))
|
||||
}
|
||||
@ -44,6 +44,9 @@ type DataProxy struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径,
|
||||
// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本,
|
||||
// 应新增独立结构而非修改这里。
|
||||
type DataAccount struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes,omitempty"`
|
||||
|
||||
@ -1656,7 +1656,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) {
|
||||
}
|
||||
|
||||
// GetUsage handles getting account usage information
|
||||
// GET /api/v1/admin/accounts/:id/usage?source=passive|active
|
||||
// GET /api/v1/admin/accounts/:id/usage?source=passive|active&force=true
|
||||
func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
@ -1665,12 +1665,13 @@ func (h *AccountHandler) GetUsage(c *gin.Context) {
|
||||
}
|
||||
|
||||
source := c.DefaultQuery("source", "active")
|
||||
force := c.Query("force") == "true"
|
||||
|
||||
var usage *service.UsageInfo
|
||||
if source == "passive" {
|
||||
usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID)
|
||||
} else {
|
||||
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID)
|
||||
usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID, force)
|
||||
}
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -2022,6 +2023,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) {
|
||||
response.Success(c, models)
|
||||
}
|
||||
|
||||
// SyncUpstreamModels handles syncing live supported models from an account's upstream.
|
||||
// POST /api/v1/admin/accounts/:id/models/sync-upstream
|
||||
func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) {
|
||||
accountID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid account ID")
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.adminService.GetAccount(c.Request.Context(), accountID)
|
||||
if err != nil {
|
||||
response.NotFound(c, "Account not found")
|
||||
return
|
||||
}
|
||||
|
||||
if h.accountTestService == nil {
|
||||
response.InternalError(c, "Account test service is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account)
|
||||
if err != nil {
|
||||
var syncErr *service.UpstreamModelSyncError
|
||||
if errors.As(err, &syncErr) {
|
||||
switch syncErr.Kind {
|
||||
case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported:
|
||||
response.BadRequest(c, syncErr.SafeMessage())
|
||||
default:
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind)
|
||||
response.Error(c, http.StatusBadGateway, syncErr.SafeMessage())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slog.Warn("sync_upstream_models_failed", "account_id", accountID)
|
||||
response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
// SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account
|
||||
// POST /api/v1/admin/accounts/:id/set-privacy
|
||||
func (h *AccountHandler) SetPrivacy(c *gin.Context) {
|
||||
|
||||
@ -3,10 +3,14 @@ package admin
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -33,6 +37,40 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
type syncUpstreamHTTPUpstream struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
if u.err != nil {
|
||||
return nil, u.err
|
||||
}
|
||||
return u.resp, nil
|
||||
}
|
||||
|
||||
func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) {
|
||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
accountTestSvc := service.NewAccountTestService(
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
upstream,
|
||||
&config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}},
|
||||
nil,
|
||||
)
|
||||
handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil)
|
||||
router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
@ -103,3 +141,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau
|
||||
require.NotEmpty(t, resp.Data)
|
||||
require.NotEqual(t, "gpt-5", resp.Data[0].ID)
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 44,
|
||||
Name: "openai-apikey-missing-key",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "No OpenAI API key is available")
|
||||
}
|
||||
|
||||
func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) {
|
||||
svc := &availableModelsAdminService{
|
||||
stubAdminService: newStubAdminService(),
|
||||
account: service.Account{
|
||||
ID: 45,
|
||||
Name: "openai-apikey-upstream-error",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "openai-key",
|
||||
"base_url": "https://openai.example.com/v1",
|
||||
},
|
||||
},
|
||||
}
|
||||
upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)),
|
||||
}}
|
||||
router := setupSyncUpstreamModelsRouter(svc, upstream)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil)
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, rec.Code)
|
||||
require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502")
|
||||
require.NotContains(t, rec.Body.String(), "SECRET_TOKEN")
|
||||
}
|
||||
|
||||
@ -17,11 +17,12 @@ import (
|
||||
type ChannelHandler struct {
|
||||
channelService *service.ChannelService
|
||||
billingService *service.BillingService
|
||||
pricingService *service.PricingService
|
||||
}
|
||||
|
||||
// NewChannelHandler creates a new admin channel handler
|
||||
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService, billingService: billingService}
|
||||
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService, pricingService *service.PricingService) *ChannelHandler {
|
||||
return &ChannelHandler{channelService: channelService, billingService: billingService, pricingService: pricingService}
|
||||
}
|
||||
|
||||
// --- Request / Response types ---
|
||||
@ -500,3 +501,34 @@ func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
|
||||
"image_output_price": pricing.ImageOutputPricePerToken,
|
||||
})
|
||||
}
|
||||
|
||||
// platformToLiteLLMProvider maps a channel platform name to the corresponding
|
||||
// LiteLLM provider string used as the key in the pricing catalog.
|
||||
var platformToLiteLLMProvider = map[string]string{
|
||||
service.PlatformAnthropic: "anthropic",
|
||||
service.PlatformOpenAI: "openai",
|
||||
service.PlatformGemini: "google",
|
||||
service.PlatformAntigravity: "anthropic",
|
||||
}
|
||||
|
||||
// SyncPricingModels 返回 LiteLLM 定价目录中指定平台的最新模型列表
|
||||
// GET /api/v1/admin/channels/pricing/sync-models?platform=anthropic
|
||||
func (h *ChannelHandler) SyncPricingModels(c *gin.Context) {
|
||||
platform := strings.ToLower(strings.TrimSpace(c.Query("platform")))
|
||||
if platform == "" {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("MISSING_PARAMETER", "platform parameter is required").
|
||||
WithMetadata(map[string]string{"param": "platform"}))
|
||||
return
|
||||
}
|
||||
|
||||
provider, ok := platformToLiteLLMProvider[platform]
|
||||
if !ok {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("UNSUPPORTED_PLATFORM",
|
||||
fmt.Sprintf("unsupported platform: %s", platform)).
|
||||
WithMetadata(map[string]string{"param": "platform"}))
|
||||
return
|
||||
}
|
||||
|
||||
models := h.pricingService.ListModelNamesByProvider(provider)
|
||||
response.Success(c, gin.H{"models": models})
|
||||
}
|
||||
|
||||
@ -3,10 +3,14 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -416,3 +420,58 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
||||
require.Nil(t, r.ImageOutputPrice)
|
||||
require.Nil(t, r.PerRequestPrice)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. SyncPricingModels handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func setupSyncPricingModelsRouter(pricingSvc *service.PricingService) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
h := &ChannelHandler{pricingService: pricingSvc}
|
||||
router.GET("/channels/pricing/sync-models", h.SyncPricingModels)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_MissingPlatform(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_UnsupportedPlatform(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform=unknown", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
func TestSyncPricingModels_ValidPlatform_EmptyService(t *testing.T) {
|
||||
svc := service.NewPricingService(nil, nil)
|
||||
router := setupSyncPricingModelsRouter(svc)
|
||||
|
||||
for _, platform := range []string{"anthropic", "openai", "gemini", "antigravity"} {
|
||||
req := httptest.NewRequest(http.MethodGet, "/channels/pricing/sync-models?platform="+platform, nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code, "platform=%s", platform)
|
||||
|
||||
var body struct {
|
||||
Data struct {
|
||||
Models []string `json:"models"`
|
||||
} `json:"data"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||
require.NotNil(t, body.Data.Models, "models must not be null for platform=%s", platform)
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ func NewChannelMonitorHandler(monitorService *service.ChannelMonitorService) *Ch
|
||||
type channelMonitorCreateRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
|
||||
APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
|
||||
Endpoint string `json:"endpoint" binding:"required,max=500"`
|
||||
APIKey string `json:"api_key" binding:"required,max=2000"`
|
||||
PrimaryModel string `json:"primary_model" binding:"required,max=200"`
|
||||
@ -54,6 +55,7 @@ type channelMonitorCreateRequest struct {
|
||||
type channelMonitorUpdateRequest struct {
|
||||
Name *string `json:"name" binding:"omitempty,max=100"`
|
||||
Provider *string `json:"provider" binding:"omitempty,oneof=openai anthropic gemini"`
|
||||
APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
|
||||
Endpoint *string `json:"endpoint" binding:"omitempty,max=500"`
|
||||
APIKey *string `json:"api_key" binding:"omitempty,max=2000"`
|
||||
PrimaryModel *string `json:"primary_model" binding:"omitempty,max=200"`
|
||||
@ -72,6 +74,7 @@ type channelMonitorResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
APIMode string `json:"api_mode"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
APIKeyMasked string `json:"api_key_masked"`
|
||||
APIKeyDecryptFailed bool `json:"api_key_decrypt_failed"`
|
||||
@ -138,6 +141,7 @@ func channelMonitorToResponse(m *service.ChannelMonitor) *channelMonitorResponse
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Provider: m.Provider,
|
||||
APIMode: m.APIMode,
|
||||
Endpoint: m.Endpoint,
|
||||
APIKeyMasked: maskAPIKey(m.APIKey),
|
||||
APIKeyDecryptFailed: m.APIKeyDecryptFailed,
|
||||
@ -303,6 +307,7 @@ func (h *ChannelMonitorHandler) Create(c *gin.Context) {
|
||||
m, err := h.monitorService.Create(c.Request.Context(), service.ChannelMonitorCreateParams{
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
APIMode: req.APIMode,
|
||||
Endpoint: req.Endpoint,
|
||||
APIKey: req.APIKey,
|
||||
PrimaryModel: req.PrimaryModel,
|
||||
@ -338,6 +343,7 @@ func (h *ChannelMonitorHandler) Update(c *gin.Context) {
|
||||
m, err := h.monitorService.Update(c.Request.Context(), id, service.ChannelMonitorUpdateParams{
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
APIMode: req.APIMode,
|
||||
Endpoint: req.Endpoint,
|
||||
APIKey: req.APIKey,
|
||||
PrimaryModel: req.PrimaryModel,
|
||||
|
||||
@ -27,6 +27,7 @@ func NewChannelMonitorRequestTemplateHandler(templateService *service.ChannelMon
|
||||
type channelMonitorTemplateCreateRequest struct {
|
||||
Name string `json:"name" binding:"required,max=100"`
|
||||
Provider string `json:"provider" binding:"required,oneof=openai anthropic gemini"`
|
||||
APIMode string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
|
||||
Description string `json:"description" binding:"max=500"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers"`
|
||||
BodyOverrideMode string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
|
||||
@ -35,6 +36,7 @@ type channelMonitorTemplateCreateRequest struct {
|
||||
|
||||
type channelMonitorTemplateUpdateRequest struct {
|
||||
Name *string `json:"name" binding:"omitempty,max=100"`
|
||||
APIMode *string `json:"api_mode" binding:"omitempty,oneof=chat_completions responses"`
|
||||
Description *string `json:"description" binding:"omitempty,max=500"`
|
||||
ExtraHeaders *map[string]string `json:"extra_headers"`
|
||||
BodyOverrideMode *string `json:"body_override_mode" binding:"omitempty,oneof=off merge replace"`
|
||||
@ -45,6 +47,7 @@ type channelMonitorTemplateResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
APIMode string `json:"api_mode"`
|
||||
Description string `json:"description"`
|
||||
ExtraHeaders map[string]string `json:"extra_headers"`
|
||||
BodyOverrideMode string `json:"body_override_mode"`
|
||||
@ -67,6 +70,7 @@ func (h *ChannelMonitorRequestTemplateHandler) toResponse(c *gin.Context, t *ser
|
||||
ID: t.ID,
|
||||
Name: t.Name,
|
||||
Provider: t.Provider,
|
||||
APIMode: t.APIMode,
|
||||
Description: t.Description,
|
||||
ExtraHeaders: headers,
|
||||
BodyOverrideMode: t.BodyOverrideMode,
|
||||
@ -93,6 +97,7 @@ func parseTemplateID(c *gin.Context) (int64, bool) {
|
||||
func (h *ChannelMonitorRequestTemplateHandler) List(c *gin.Context) {
|
||||
items, err := h.templateService.List(c.Request.Context(), service.ChannelMonitorRequestTemplateListParams{
|
||||
Provider: strings.TrimSpace(c.Query("provider")),
|
||||
APIMode: strings.TrimSpace(c.Query("api_mode")),
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -129,6 +134,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Create(c *gin.Context) {
|
||||
t, err := h.templateService.Create(c.Request.Context(), service.ChannelMonitorRequestTemplateCreateParams{
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
APIMode: req.APIMode,
|
||||
Description: req.Description,
|
||||
ExtraHeaders: req.ExtraHeaders,
|
||||
BodyOverrideMode: req.BodyOverrideMode,
|
||||
@ -154,6 +160,7 @@ func (h *ChannelMonitorRequestTemplateHandler) Update(c *gin.Context) {
|
||||
}
|
||||
t, err := h.templateService.Update(c.Request.Context(), id, service.ChannelMonitorRequestTemplateUpdateParams{
|
||||
Name: req.Name,
|
||||
APIMode: req.APIMode,
|
||||
Description: req.Description,
|
||||
ExtraHeaders: req.ExtraHeaders,
|
||||
BodyOverrideMode: req.BodyOverrideMode,
|
||||
@ -209,6 +216,7 @@ type associatedMonitorBriefResponse struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Provider string `json:"provider"`
|
||||
APIMode string `json:"api_mode"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
@ -227,7 +235,7 @@ func (h *ChannelMonitorRequestTemplateHandler) AssociatedMonitors(c *gin.Context
|
||||
out := make([]associatedMonitorBriefResponse, 0, len(items))
|
||||
for _, m := range items {
|
||||
out = append(out, associatedMonitorBriefResponse{
|
||||
ID: m.ID, Name: m.Name, Provider: m.Provider, Enabled: m.Enabled,
|
||||
ID: m.ID, Name: m.Name, Provider: m.Provider, APIMode: m.APIMode, Enabled: m.Enabled,
|
||||
})
|
||||
}
|
||||
response.Success(c, gin.H{"items": out})
|
||||
|
||||
@ -46,6 +46,8 @@ type contentModerationConfigRequest struct {
|
||||
HitRetentionDays *int `json:"hit_retention_days"`
|
||||
NonHitRetentionDays *int `json:"non_hit_retention_days"`
|
||||
PreHashCheckEnabled *bool `json:"pre_hash_check_enabled"`
|
||||
BlockedKeywords *[]string `json:"blocked_keywords"`
|
||||
KeywordBlockingMode *string `json:"keyword_blocking_mode"`
|
||||
}
|
||||
|
||||
type contentModerationAPIKeyTestRequest struct {
|
||||
@ -103,6 +105,8 @@ func (h *ContentModerationHandler) UpdateConfig(c *gin.Context) {
|
||||
HitRetentionDays: req.HitRetentionDays,
|
||||
NonHitRetentionDays: req.NonHitRetentionDays,
|
||||
PreHashCheckEnabled: req.PreHashCheckEnabled,
|
||||
BlockedKeywords: req.BlockedKeywords,
|
||||
KeywordBlockingMode: req.KeywordBlockingMode,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。
|
||||
keyRaw, _ := json.Marshal(struct {
|
||||
V int `json:"v"`
|
||||
Day string `json:"day"`
|
||||
UserIDs []int64 `json:"user_ids"`
|
||||
}{
|
||||
V: 2, // bump 当响应结构变化(如加入 by_platform 时)
|
||||
Day: timezone.Today().Format("2006-01-02"),
|
||||
UserIDs: userIDs,
|
||||
})
|
||||
cacheKey := string(keyRaw)
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -386,79 +384,6 @@ func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
|
||||
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
|
||||
}
|
||||
|
||||
// RetryRequestErrorClient retries the client request based on stored request body.
|
||||
// POST /api/v1/admin/ops/request-errors/:id/retry-client
|
||||
func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
|
||||
// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
|
||||
func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
idxStr := strings.TrimSpace(c.Param("idx"))
|
||||
idx, err := strconv.Atoi(idxStr)
|
||||
if err != nil || idx < 0 {
|
||||
response.BadRequest(c, "Invalid upstream idx")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ResolveRequestError toggles resolved status.
|
||||
// PUT /api/v1/admin/ops/request-errors/:id/resolve
|
||||
func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
|
||||
@ -566,39 +491,6 @@ func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
|
||||
h.GetErrorLogByID(c)
|
||||
}
|
||||
|
||||
// RetryUpstreamError retries upstream error using the original account_id.
|
||||
// POST /api/v1/admin/ops/upstream-errors/:id/retry
|
||||
func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ResolveUpstreamError toggles resolved status.
|
||||
// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
|
||||
func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
|
||||
@ -708,106 +600,10 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
|
||||
response.Paginated(c, out.Items, out.Total, out.Page, out.PageSize)
|
||||
}
|
||||
|
||||
type opsRetryRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
PinnedAccountID *int64 `json:"pinned_account_id"`
|
||||
Force bool `json:"force"`
|
||||
}
|
||||
|
||||
type opsResolveRequest struct {
|
||||
Resolved bool `json:"resolved"`
|
||||
}
|
||||
|
||||
// RetryErrorRequest retries a failed request using stored request_body.
|
||||
// POST /api/v1/admin/ops/errors/:id/retry
|
||||
func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware.GetAuthSubjectFromContext(c)
|
||||
if !ok || subject.UserID <= 0 {
|
||||
response.Error(c, http.StatusUnauthorized, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
req := opsRetryRequest{Mode: service.OpsRetryModeClient}
|
||||
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Mode) == "" {
|
||||
req.Mode = service.OpsRetryModeClient
|
||||
}
|
||||
|
||||
// Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
|
||||
_ = req.Force
|
||||
|
||||
// Legacy endpoint safety: only allow retrying the client request here.
|
||||
// Upstream retries must go through the split endpoints.
|
||||
if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
|
||||
response.BadRequest(c, "upstream retry is not supported on this endpoint")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ListRetryAttempts lists retry attempts for an error log.
|
||||
// GET /api/v1/admin/ops/errors/:id/retries
|
||||
func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
|
||||
if h.opsService == nil {
|
||||
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
|
||||
return
|
||||
}
|
||||
if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := strings.TrimSpace(c.Param("id"))
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
response.BadRequest(c, "Invalid error id")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if v := strings.TrimSpace(c.Query("limit")); v != "" {
|
||||
n, err := strconv.Atoi(v)
|
||||
if err != nil || n <= 0 {
|
||||
response.BadRequest(c, "Invalid limit")
|
||||
return
|
||||
}
|
||||
limit = n
|
||||
}
|
||||
|
||||
items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
response.Success(c, items)
|
||||
}
|
||||
|
||||
// UpdateErrorResolution allows manual resolve/unresolve.
|
||||
// PUT /api/v1/admin/ops/errors/:id/resolve
|
||||
func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
|
||||
@ -839,7 +635,7 @@ func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uid := subject.UserID
|
||||
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
|
||||
if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service.
|
||||
|
||||
// GenerateRedeemCodesRequest represents generate redeem codes request
|
||||
type GenerateRedeemCodesRequest struct {
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
|
||||
Count int `json:"count" binding:"required,min=1,max=100"`
|
||||
Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
|
||||
Value float64 `json:"value"`
|
||||
GroupID *int64 `json:"group_id"` // 订阅类型必填
|
||||
ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
|
||||
}
|
||||
|
||||
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
|
||||
// Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。
|
||||
type CreateAndRedeemCodeRequest struct {
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
|
||||
Notes string `json:"notes"`
|
||||
Code string `json:"code" binding:"required,min=3,max=128"`
|
||||
Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容)
|
||||
Value float64 `json:"value" binding:"required"`
|
||||
UserID int64 `json:"user_id" binding:"required,gt=0"`
|
||||
GroupID *int64 `json:"group_id"` // subscription 类型必填
|
||||
ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减
|
||||
Notes string `json:"notes"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"`
|
||||
}
|
||||
|
||||
func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) {
|
||||
if expiresAt != nil && expiresInDays != nil {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if expiresInDays != nil {
|
||||
if *expiresInDays <= 0 {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero")
|
||||
}
|
||||
expires := now.AddDate(0, 0, *expiresInDays)
|
||||
return &expires, nil
|
||||
}
|
||||
if expiresAt == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
expires := expiresAt.UTC()
|
||||
if !expires.After(now) {
|
||||
return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future")
|
||||
}
|
||||
return &expires, nil
|
||||
}
|
||||
|
||||
// List handles listing all redeem codes with pagination
|
||||
@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{
|
||||
Count: req.Count,
|
||||
@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
|
||||
Value: req.Value,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
|
||||
existing, err := h.redeemService.GetByCode(ctx, req.Code)
|
||||
if err == nil {
|
||||
@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
|
||||
Notes: req.Notes,
|
||||
GroupID: req.GroupID,
|
||||
ValidityDays: req.ValidityDays,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
if createErr != nil {
|
||||
// Unique code race: if code now exists, use idempotent semantics by used_by.
|
||||
@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis
|
||||
}
|
||||
|
||||
// If previous run created the code but crashed before redeem, redeem it now.
|
||||
if existing.IsExpired() {
|
||||
return nil, service.ErrRedeemCodeExpired
|
||||
}
|
||||
if existing.CanUse() {
|
||||
redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
|
||||
if err == nil {
|
||||
@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
// Write header
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil {
|
||||
if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
if code.UsedAt != nil {
|
||||
usedAt = code.UsedAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
expiresAt := ""
|
||||
if code.ExpiresAt != nil {
|
||||
expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
if err := writer.Write([]string{
|
||||
fmt.Sprintf("%d", code.ID),
|
||||
code.Code,
|
||||
@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) {
|
||||
usedBy,
|
||||
usedByEmail,
|
||||
usedAt,
|
||||
expiresAt,
|
||||
code.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}); err != nil {
|
||||
response.InternalError(c, "Failed to export redeem codes: "+err.Error())
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) {
|
||||
assert.NotEqual(t, http.StatusBadRequest, code,
|
||||
"balance type should not require group_id or validity_days")
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) {
|
||||
days := 3
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, expiresAt)
|
||||
require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) {
|
||||
past := time.Now().UTC().Add(-time.Minute)
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) {
|
||||
days := 0
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) {
|
||||
future := time.Now().UTC().Add(time.Hour)
|
||||
days := 3
|
||||
expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, expiresAt)
|
||||
}
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@ -60,10 +62,11 @@ type SettingHandler struct {
|
||||
opsService *service.OpsService
|
||||
paymentConfigService *service.PaymentConfigService
|
||||
paymentService *service.PaymentService
|
||||
userAttributeService *service.UserAttributeService
|
||||
}
|
||||
|
||||
// NewSettingHandler 创建系统设置处理器
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler {
|
||||
func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler {
|
||||
return &SettingHandler{
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser
|
||||
opsService: opsService,
|
||||
paymentConfigService: paymentConfigService,
|
||||
paymentService: paymentService,
|
||||
userAttributeService: userAttributeService,
|
||||
}
|
||||
}
|
||||
|
||||
@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: settings.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: settings.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured,
|
||||
DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: settings.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: settings.WeChatConnectAppID,
|
||||
WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured,
|
||||
@ -258,6 +278,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode,
|
||||
PaymentAlipayForceQRCode: paymentCfg.AlipayForceQRCode,
|
||||
|
||||
ChannelMonitorEnabled: settings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds,
|
||||
@ -376,6 +397,24 @@ type UpdateSettingsRequest struct {
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// DingTalk Connect OAuth 登录
|
||||
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
|
||||
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
|
||||
DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"`
|
||||
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
|
||||
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
|
||||
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
|
||||
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
|
||||
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
|
||||
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
|
||||
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
|
||||
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
|
||||
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
|
||||
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
|
||||
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
|
||||
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
|
||||
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
|
||||
|
||||
// WeChat Connect OAuth 登录
|
||||
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
|
||||
WeChatConnectAppID string `json:"wechat_connect_app_id"`
|
||||
@ -446,45 +485,50 @@ type UpdateSettingsRequest struct {
|
||||
CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
|
||||
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
|
||||
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
|
||||
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
|
||||
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
|
||||
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
|
||||
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
|
||||
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
|
||||
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
|
||||
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
|
||||
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
|
||||
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
|
||||
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
|
||||
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
|
||||
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"`
|
||||
AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"`
|
||||
AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"`
|
||||
AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"`
|
||||
DefaultUserRPMLimit int `json:"default_user_rpm_limit"`
|
||||
DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
|
||||
AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"`
|
||||
AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"`
|
||||
AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"`
|
||||
AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"`
|
||||
AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"`
|
||||
AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"`
|
||||
AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"`
|
||||
AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"`
|
||||
AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"`
|
||||
AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"`
|
||||
AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"`
|
||||
AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"`
|
||||
AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"`
|
||||
AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"`
|
||||
AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"`
|
||||
AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"`
|
||||
AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"`
|
||||
AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"`
|
||||
AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"`
|
||||
AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"`
|
||||
AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"`
|
||||
AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"`
|
||||
AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"`
|
||||
AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"`
|
||||
AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"`
|
||||
AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"`
|
||||
AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"`
|
||||
AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"`
|
||||
AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"`
|
||||
AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"`
|
||||
AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"`
|
||||
AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"`
|
||||
AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"`
|
||||
AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"`
|
||||
AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"`
|
||||
ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"`
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
@ -560,6 +604,9 @@ type UpdateSettingsRequest struct {
|
||||
PaymentCancelRateLimitUnit *string `json:"payment_cancel_rate_limit_unit"`
|
||||
PaymentCancelRateLimitMode *string `json:"payment_cancel_rate_limit_window_mode"`
|
||||
|
||||
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
|
||||
PaymentAlipayForceQRCode *bool `json:"payment_alipay_force_qrcode"`
|
||||
|
||||
// Channel Monitor feature switch
|
||||
ChannelMonitorEnabled *bool `json:"channel_monitor_enabled"`
|
||||
ChannelMonitorDefaultIntervalSeconds *int `json:"channel_monitor_default_interval_seconds"`
|
||||
@ -661,6 +708,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions)
|
||||
req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions)
|
||||
req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions)
|
||||
req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions)
|
||||
|
||||
// SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置
|
||||
// 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置
|
||||
@ -777,6 +825,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// DingTalk Connect 参数验证
|
||||
// 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none,
|
||||
// 避免任何直连 admin API 的客户端把死值写回 DB(前端 UI 已无此选项)。
|
||||
req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy)
|
||||
|
||||
if req.DingTalkConnectEnabled {
|
||||
req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID)
|
||||
req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret)
|
||||
req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL)
|
||||
req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy)
|
||||
req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID)
|
||||
|
||||
if req.DingTalkConnectClientID == "" {
|
||||
response.BadRequest(c, "DingTalk Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.DingTalkConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "DingTalk Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.DingTalkConnectClientSecret == "" {
|
||||
if previousSettings.DingTalkConnectClientSecret == "" {
|
||||
response.BadRequest(c, "DingTalk Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret
|
||||
}
|
||||
|
||||
// Corp 策略校验(V1/V4 fail-closed)
|
||||
dingTalkCfg := config.DingTalkConnectConfig{
|
||||
Enabled: true,
|
||||
DingTalkAppKind: "internal_app", // 硬编码:settings 层仅支持 internal_app
|
||||
AppType: "internal", // 对于 internal_only 策略的默认值
|
||||
CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
|
||||
InternalCorpID: req.DingTalkConnectInternalCorpID,
|
||||
}
|
||||
// 若未填 corp_restriction_policy,保留已有配置
|
||||
if dingTalkCfg.CorpRestrictionPolicy == "" {
|
||||
dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy
|
||||
}
|
||||
// 对于 internal_only 策略,app_type 必须为 internal(V1 校验)
|
||||
if dingTalkCfg.CorpRestrictionPolicy == "internal_only" {
|
||||
dingTalkCfg.AppType = "internal"
|
||||
} else {
|
||||
dingTalkCfg.AppType = "public"
|
||||
}
|
||||
if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil {
|
||||
response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil)
|
||||
return
|
||||
}
|
||||
|
||||
// bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false,
|
||||
// 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。
|
||||
if dingTalkCfg.CorpRestrictionPolicy != "internal_only" {
|
||||
req.DingTalkConnectBypassRegistration = false
|
||||
// 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。
|
||||
req.DingTalkConnectSyncCorpEmail = false
|
||||
req.DingTalkConnectSyncDisplayName = false
|
||||
req.DingTalkConnectSyncDept = false
|
||||
}
|
||||
// 身份同步目标 attr key:trimSpace + 空值 fallback 到默认值
|
||||
req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey)
|
||||
if req.DingTalkConnectSyncCorpEmailAttrKey == "" {
|
||||
req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email"
|
||||
}
|
||||
req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey)
|
||||
if req.DingTalkConnectSyncDisplayNameAttrKey == "" {
|
||||
req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name"
|
||||
}
|
||||
req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey)
|
||||
if req.DingTalkConnectSyncDeptAttrKey == "" {
|
||||
req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department"
|
||||
}
|
||||
// 身份同步目标 attr 显示名称:trim + 空值 fallback 到默认中文名
|
||||
req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName)
|
||||
if req.DingTalkConnectSyncCorpEmailAttrName == "" {
|
||||
req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱"
|
||||
}
|
||||
req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName)
|
||||
if req.DingTalkConnectSyncDisplayNameAttrName == "" {
|
||||
req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名"
|
||||
}
|
||||
req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName)
|
||||
if req.DingTalkConnectSyncDeptAttrName == "" {
|
||||
req.DingTalkConnectSyncDeptAttrName = "钉钉部门"
|
||||
}
|
||||
}
|
||||
|
||||
if req.WeChatConnectEnabled {
|
||||
req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID)
|
||||
req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret)
|
||||
@ -1272,113 +1414,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
WeChatConnectEnabled: req.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: req.WeChatConnectAppID,
|
||||
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
|
||||
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
|
||||
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
|
||||
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
|
||||
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
|
||||
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
|
||||
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
|
||||
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
|
||||
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
|
||||
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
|
||||
WeChatConnectMode: req.WeChatConnectMode,
|
||||
WeChatConnectScopes: req.WeChatConnectScopes,
|
||||
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
|
||||
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
|
||||
OIDCConnectEnabled: req.OIDCConnectEnabled,
|
||||
OIDCConnectProviderName: req.OIDCConnectProviderName,
|
||||
OIDCConnectClientID: req.OIDCConnectClientID,
|
||||
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
|
||||
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
|
||||
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
|
||||
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
|
||||
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
|
||||
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
|
||||
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
|
||||
OIDCConnectScopes: req.OIDCConnectScopes,
|
||||
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
|
||||
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
|
||||
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
|
||||
OIDCConnectUsePKCE: oidcUsePKCE,
|
||||
OIDCConnectValidateIDToken: oidcValidateIDToken,
|
||||
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
|
||||
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
|
||||
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
TableDefaultPageSize: req.TableDefaultPageSize,
|
||||
TablePageSizeOptions: req.TablePageSizeOptions,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
FrontendURL: req.FrontendURL,
|
||||
InvitationCodeEnabled: req.InvitationCodeEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
LoginAgreementEnabled: req.LoginAgreementEnabled,
|
||||
LoginAgreementMode: loginAgreementMode,
|
||||
LoginAgreementUpdatedAt: loginAgreementUpdatedAt,
|
||||
LoginAgreementDocuments: loginAgreementDocuments,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: req.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: req.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecret: req.DingTalkConnectClientSecret,
|
||||
DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: req.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: req.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: req.WeChatConnectAppID,
|
||||
WeChatConnectAppSecret: req.WeChatConnectAppSecret,
|
||||
WeChatConnectOpenAppID: req.WeChatConnectOpenAppID,
|
||||
WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret,
|
||||
WeChatConnectMPAppID: req.WeChatConnectMPAppID,
|
||||
WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret,
|
||||
WeChatConnectMobileAppID: req.WeChatConnectMobileAppID,
|
||||
WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret,
|
||||
WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled,
|
||||
WeChatConnectMPEnabled: req.WeChatConnectMPEnabled,
|
||||
WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled,
|
||||
WeChatConnectMode: req.WeChatConnectMode,
|
||||
WeChatConnectScopes: req.WeChatConnectScopes,
|
||||
WeChatConnectRedirectURL: req.WeChatConnectRedirectURL,
|
||||
WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL,
|
||||
OIDCConnectEnabled: req.OIDCConnectEnabled,
|
||||
OIDCConnectProviderName: req.OIDCConnectProviderName,
|
||||
OIDCConnectClientID: req.OIDCConnectClientID,
|
||||
OIDCConnectClientSecret: req.OIDCConnectClientSecret,
|
||||
OIDCConnectIssuerURL: req.OIDCConnectIssuerURL,
|
||||
OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL,
|
||||
OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL,
|
||||
OIDCConnectTokenURL: req.OIDCConnectTokenURL,
|
||||
OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL,
|
||||
OIDCConnectJWKSURL: req.OIDCConnectJWKSURL,
|
||||
OIDCConnectScopes: req.OIDCConnectScopes,
|
||||
OIDCConnectRedirectURL: req.OIDCConnectRedirectURL,
|
||||
OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL,
|
||||
OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod,
|
||||
OIDCConnectUsePKCE: oidcUsePKCE,
|
||||
OIDCConnectValidateIDToken: oidcValidateIDToken,
|
||||
OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs,
|
||||
OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds,
|
||||
OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified,
|
||||
OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath,
|
||||
OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath,
|
||||
OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath,
|
||||
GitHubOAuthEnabled: req.GitHubOAuthEnabled,
|
||||
GitHubOAuthClientID: req.GitHubOAuthClientID,
|
||||
GitHubOAuthClientSecret: req.GitHubOAuthClientSecret,
|
||||
GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL,
|
||||
GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL,
|
||||
GoogleOAuthEnabled: req.GoogleOAuthEnabled,
|
||||
GoogleOAuthClientID: req.GoogleOAuthClientID,
|
||||
GoogleOAuthClientSecret: req.GoogleOAuthClientSecret,
|
||||
GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL,
|
||||
GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
HideCcsImportButton: req.HideCcsImportButton,
|
||||
PurchaseSubscriptionEnabled: purchaseEnabled,
|
||||
PurchaseSubscriptionURL: purchaseURL,
|
||||
TableDefaultPageSize: req.TableDefaultPageSize,
|
||||
TablePageSizeOptions: req.TablePageSizeOptions,
|
||||
CustomMenuItems: customMenuJSON,
|
||||
CustomEndpoints: customEndpointsJSON,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
AffiliateRebateRate: affiliateRebateRate,
|
||||
AffiliateRebateFreezeHours: affiliateRebateFreezeHours,
|
||||
AffiliateRebateDurationDays: affiliateRebateDurationDays,
|
||||
AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap,
|
||||
DefaultUserRPMLimit: req.DefaultUserRPMLimit,
|
||||
DefaultSubscriptions: defaultSubscriptions,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
MinClaudeCodeVersion: req.MinClaudeCodeVersion,
|
||||
MaxClaudeCodeVersion: req.MaxClaudeCodeVersion,
|
||||
AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling,
|
||||
BackendModeEnabled: req.BackendModeEnabled,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@ -1574,6 +1732,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind),
|
||||
},
|
||||
DingTalk: service.ProviderDefaultGrantSettings{
|
||||
Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance),
|
||||
Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency),
|
||||
Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions),
|
||||
GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup),
|
||||
GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind),
|
||||
},
|
||||
ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup),
|
||||
}
|
||||
if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil {
|
||||
@ -1613,6 +1778,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
CancelRateLimitWindow: req.PaymentCancelRateLimitWindow,
|
||||
CancelRateLimitUnit: req.PaymentCancelRateLimitUnit,
|
||||
CancelRateLimitMode: req.PaymentCancelRateLimitMode,
|
||||
AlipayForceQRCode: req.PaymentAlipayForceQRCode,
|
||||
}
|
||||
if err := h.paymentConfigService.UpdatePaymentConfig(c.Request.Context(), paymentReq); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -1632,6 +1798,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings)
|
||||
updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context())
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@ -1682,6 +1849,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled,
|
||||
DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID,
|
||||
DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured,
|
||||
DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL,
|
||||
DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy,
|
||||
DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID,
|
||||
DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration,
|
||||
DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail,
|
||||
DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName,
|
||||
DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept,
|
||||
DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey,
|
||||
DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey,
|
||||
DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey,
|
||||
DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName,
|
||||
DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName,
|
||||
DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName,
|
||||
WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled,
|
||||
WeChatConnectAppID: updatedSettings.WeChatConnectAppID,
|
||||
WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured,
|
||||
@ -1803,6 +1986,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow,
|
||||
PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit,
|
||||
PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode,
|
||||
PaymentAlipayForceQRCode: updatedPaymentCfg.AlipayForceQRCode,
|
||||
|
||||
ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled,
|
||||
ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds,
|
||||
@ -1822,6 +2006,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
// hasPaymentFields returns true if any payment-related field was explicitly provided.
|
||||
// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes.
|
||||
func mapDingTalkValidateError(err error) string {
|
||||
switch {
|
||||
case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch):
|
||||
return "dingtalk_apptype_mismatch"
|
||||
case errors.Is(err, config.ErrDingTalkV4InvalidAppKind):
|
||||
return "dingtalk_app_kind_invalid"
|
||||
default:
|
||||
return "dingtalk_corp_config_invalid"
|
||||
}
|
||||
}
|
||||
|
||||
func hasPaymentFields(req UpdateSettingsRequest) bool {
|
||||
return req.PaymentEnabled != nil || req.PaymentMinAmount != nil ||
|
||||
req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil ||
|
||||
@ -1832,7 +2028,8 @@ func hasPaymentFields(req UpdateSettingsRequest) bool {
|
||||
req.PaymentProductNameSuffix != nil || req.PaymentHelpImageURL != nil ||
|
||||
req.PaymentHelpText != nil || req.PaymentCancelRateLimitEnabled != nil ||
|
||||
req.PaymentCancelRateLimitMax != nil || req.PaymentCancelRateLimitWindow != nil ||
|
||||
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil
|
||||
req.PaymentCancelRateLimitUnit != nil || req.PaymentCancelRateLimitMode != nil ||
|
||||
req.PaymentAlipayForceQRCode != nil
|
||||
}
|
||||
|
||||
func (h *SettingHandler) auditSettingsUpdate(c *gin.Context, before *service.SystemSettings, after *service.SystemSettings, beforeAuthSourceDefaults *service.AuthSourceDefaultSettings, afterAuthSourceDefaults *service.AuthSourceDefaultSettings, req UpdateSettingsRequest) {
|
||||
@ -1935,6 +2132,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled {
|
||||
changed = append(changed, "dingtalk_connect_enabled")
|
||||
}
|
||||
if before.DingTalkConnectClientID != after.DingTalkConnectClientID {
|
||||
changed = append(changed, "dingtalk_connect_client_id")
|
||||
}
|
||||
if req.DingTalkConnectClientSecret != "" {
|
||||
changed = append(changed, "dingtalk_connect_client_secret")
|
||||
}
|
||||
if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL {
|
||||
changed = append(changed, "dingtalk_connect_redirect_url")
|
||||
}
|
||||
if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy {
|
||||
changed = append(changed, "dingtalk_connect_corp_restriction_policy")
|
||||
}
|
||||
if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID {
|
||||
changed = append(changed, "dingtalk_connect_internal_corp_id")
|
||||
}
|
||||
if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration {
|
||||
changed = append(changed, "dingtalk_connect_bypass_registration")
|
||||
}
|
||||
if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail {
|
||||
changed = append(changed, "dingtalk_connect_sync_corp_email")
|
||||
}
|
||||
if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName {
|
||||
changed = append(changed, "dingtalk_connect_sync_display_name")
|
||||
}
|
||||
if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept {
|
||||
changed = append(changed, "dingtalk_connect_sync_dept")
|
||||
}
|
||||
if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key")
|
||||
}
|
||||
if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_display_name_attr_key")
|
||||
}
|
||||
if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey {
|
||||
changed = append(changed, "dingtalk_connect_sync_dept_attr_key")
|
||||
}
|
||||
if before.WeChatConnectEnabled != after.WeChatConnectEnabled {
|
||||
changed = append(changed, "wechat_connect_enabled")
|
||||
}
|
||||
@ -2246,6 +2482,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource
|
||||
{name: "wechat", before: before.WeChat, after: after.WeChat},
|
||||
{name: "github", before: before.GitHub, after: after.GitHub},
|
||||
{name: "google", before: before.Google, after: after.Google},
|
||||
{name: "dingtalk", before: before.DingTalk, after: after.DingTalk},
|
||||
}
|
||||
for _, field := range fields {
|
||||
if field.before.Balance != field.after.Balance {
|
||||
@ -2350,6 +2587,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults
|
||||
data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions
|
||||
data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup
|
||||
data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind
|
||||
data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance
|
||||
data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency
|
||||
data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions
|
||||
data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup
|
||||
data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind
|
||||
data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance
|
||||
data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency
|
||||
data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions
|
||||
@ -3044,3 +3286,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) {
|
||||
}
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name)
|
||||
// 兜底 upsert 对应 user attribute definition:不存在则创建;存在但 name 不同则更新 name
|
||||
// (type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。
|
||||
// 失败仅记录日志,不阻塞 settings 保存。
|
||||
func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) {
|
||||
if h.userAttributeService == nil || settings == nil {
|
||||
return
|
||||
}
|
||||
if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" {
|
||||
return
|
||||
}
|
||||
if settings.DingTalkConnectSyncDisplayName {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText)
|
||||
}
|
||||
if settings.DingTalkConnectSyncCorpEmail {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail)
|
||||
}
|
||||
if settings.DingTalkConnectSyncDept {
|
||||
h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key)
|
||||
if err == nil && existing != nil {
|
||||
if strings.TrimSpace(name) != "" && existing.Name != name {
|
||||
if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{
|
||||
Name: &name,
|
||||
}); err != nil {
|
||||
slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name)
|
||||
}
|
||||
return
|
||||
}
|
||||
if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{
|
||||
Key: key,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: attrType,
|
||||
Enabled: true,
|
||||
}); err != nil {
|
||||
slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error())
|
||||
return
|
||||
}
|
||||
slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType)
|
||||
}
|
||||
|
||||
@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) {
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"registration_enabled": true,
|
||||
@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": false,
|
||||
@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul
|
||||
ClockSkewSeconds: 120,
|
||||
},
|
||||
})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource(
|
||||
},
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"promo_code_enabled": true,
|
||||
@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu
|
||||
err: errors.New("write auth source defaults failed"),
|
||||
}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil)
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
body := map[string]any{
|
||||
"registration_enabled": true,
|
||||
|
||||
319
backend/internal/handler/admin/setting_handler_dingtalk_test.go
Normal file
319
backend/internal/handler/admin/setting_handler_dingtalk_test.go
Normal file
@ -0,0 +1,319 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub(已在 setting_handler_auth_source_defaults_test.go 定义)
|
||||
|
||||
func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) {
|
||||
repo := &settingHandlerRepoStub{values: map[string]string{}}
|
||||
svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}})
|
||||
handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil)
|
||||
return handler, repo
|
||||
}
|
||||
|
||||
// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。
|
||||
func baseValidDingTalkBody() map[string]any {
|
||||
return map[string]any{
|
||||
"dingtalk_connect_enabled": true,
|
||||
"dingtalk_connect_client_id": "test-client-id",
|
||||
"dingtalk_connect_client_secret": "test-client-secret",
|
||||
"dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback",
|
||||
"dingtalk_connect_corp_restriction_policy": "none",
|
||||
}
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A:
|
||||
// internal_only + internal_corp_id="" 应通过校验(→ 200),不再是 400。
|
||||
func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200
|
||||
func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "none"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_enabled"])
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200
|
||||
func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_internal_corp_id"] = "ding-corp-123"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。
|
||||
// 必须用 policy=internal_only:bypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。
|
||||
func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_bypass_registration"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_bypass_registration"])
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。
|
||||
// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为
|
||||
// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。
|
||||
func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := map[string]any{
|
||||
"dingtalk_connect_enabled": false,
|
||||
"dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝
|
||||
"dingtalk_connect_corp_restriction_policy": "internal_only",
|
||||
}
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。
|
||||
func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only")
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only")
|
||||
require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。
|
||||
func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, _ := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "none"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none")
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none")
|
||||
require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容:
|
||||
// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中)
|
||||
// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。
|
||||
func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "whitelist"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy],
|
||||
"stale whitelist 应在写入路径被 coerce 为 none")
|
||||
}
|
||||
|
||||
// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。
|
||||
func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
t.Run("custom_attr_keys_saved", func(t *testing.T) {
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
body["dingtalk_connect_sync_corp_email"] = true
|
||||
body["dingtalk_connect_sync_display_name"] = true
|
||||
body["dingtalk_connect_sync_dept"] = true
|
||||
body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr"
|
||||
body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr"
|
||||
body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr"
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 验证写入 DB 的 key
|
||||
require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
|
||||
require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
|
||||
require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
|
||||
|
||||
// 验证响应中的 attr key
|
||||
var resp response.Response
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||
data, ok := resp.Data.(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"])
|
||||
require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"])
|
||||
require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"])
|
||||
})
|
||||
|
||||
t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) {
|
||||
handler, repo := newDingTalkSettingsHandler()
|
||||
|
||||
body := baseValidDingTalkBody()
|
||||
body["dingtalk_connect_corp_restriction_policy"] = "internal_only"
|
||||
// 不传 attr key → 写入层 fallback 到默认值
|
||||
body["dingtalk_connect_sync_corp_email_attr_key"] = ""
|
||||
body["dingtalk_connect_sync_display_name_attr_key"] = ""
|
||||
body["dingtalk_connect_sync_dept_attr_key"] = ""
|
||||
|
||||
rawBody, err := json.Marshal(body)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
handler.UpdateSettings(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// 空值应 fallback 到默认值并持久化
|
||||
require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey])
|
||||
require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey])
|
||||
require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey])
|
||||
})
|
||||
}
|
||||
398
backend/internal/handler/auth_dingtalk_client.go
Normal file
398
backend/internal/handler/auth_dingtalk_client.go
Normal file
@ -0,0 +1,398 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集
|
||||
type dingTalkClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
}
|
||||
|
||||
type DingTalkClient struct {
|
||||
cfg dingTalkClientConfig
|
||||
appToken string
|
||||
appTokenExp time.Time // 钉钉 7200s,留 200s 余量 → 7000s
|
||||
mu sync.Mutex
|
||||
httpClient *http.Client
|
||||
// TODO(multi-instance): Redis 集中缓存 appToken
|
||||
}
|
||||
|
||||
type DingTalkUserTokenResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpireIn int64 `json:"expireIn"`
|
||||
CorpID string `json:"corpId"`
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) {
|
||||
body := map[string]string{
|
||||
"clientId": c.cfg.ClientID,
|
||||
"clientSecret": c.cfg.ClientSecret,
|
||||
"code": code,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
payload, _ := json.Marshal(body)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var out DingTalkUserTokenResp
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(out.AccessToken) == "" {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
type DingTalkAPIError struct {
|
||||
Code string
|
||||
Message string
|
||||
HTTP int
|
||||
}
|
||||
|
||||
func (e *DingTalkAPIError) Error() string {
|
||||
return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP)
|
||||
}
|
||||
|
||||
func parseDingTalkErr(raw []byte, status int) error {
|
||||
var v struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
_ = json.Unmarshal(raw, &v)
|
||||
code := v.Code
|
||||
if code == "" && v.ErrCode != 0 {
|
||||
code = fmt.Sprintf("%d", v.ErrCode)
|
||||
}
|
||||
msg := v.Message
|
||||
if msg == "" {
|
||||
msg = v.ErrMsg
|
||||
}
|
||||
return &DingTalkAPIError{Code: code, Message: msg, HTTP: status}
|
||||
}
|
||||
|
||||
// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。
|
||||
// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。
|
||||
func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
req.Header.Set("x-acs-dingtalk-access-token", userToken)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
UnionID string `json:"unionId"`
|
||||
Nick string `json:"nick"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if strings.TrimSpace(v.UnionID) == "" {
|
||||
return "", "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return v.UnionID, v.Nick, nil
|
||||
}
|
||||
|
||||
type DingTalkStaffInfo struct {
|
||||
UserID string
|
||||
Name string // 企业内真实姓名(钉钉企业管理后台配置)
|
||||
Nickname string // 钉钉个人昵称(用户自己设置)
|
||||
Email string
|
||||
DeptIDs []int64
|
||||
// CorpID 不来自 staff 接口,来自 userToken;不在此 struct
|
||||
}
|
||||
|
||||
// dingTalkOAPIBase 推导钉钉旧版 OAPI base URL(host: api.dingtalk.com → oapi.dingtalk.com)。
|
||||
// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。
|
||||
func (c *DingTalkClient) dingTalkOAPIBase() string {
|
||||
u, err := url.Parse(c.cfg.UserInfoURL)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return "https://oapi.dingtalk.com"
|
||||
}
|
||||
host := u.Host
|
||||
if strings.HasPrefix(host, "api.") {
|
||||
host = "oapi." + strings.TrimPrefix(host, "api.")
|
||||
}
|
||||
return u.Scheme + "://" + host
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.appToken != "" && time.Now().Before(c.appTokenExp) {
|
||||
return c.appToken, nil
|
||||
}
|
||||
body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken
|
||||
// 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明)
|
||||
appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1)
|
||||
if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") {
|
||||
appTokenURL = c.cfg.TokenURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ExpireIn int64 `json:"expireIn"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if v.AccessToken == "" {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
c.appToken = v.AccessToken
|
||||
ttl := v.ExpireIn
|
||||
if ttl > 200 {
|
||||
ttl -= 200
|
||||
}
|
||||
c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
return c.appToken, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
body := map[string]string{"unionid": unionID}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX
|
||||
// access_token 通过 query string 传递(不是 header)
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
if strings.TrimSpace(v.Result.UserID) == "" {
|
||||
return "", parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return v.Result.UserID, nil
|
||||
}
|
||||
|
||||
// DingTalkDeptInfo 部门信息(topapi/v2/department/get 返回子集)
|
||||
type DingTalkDeptInfo struct {
|
||||
DeptID int64
|
||||
Name string
|
||||
ParentID int64
|
||||
}
|
||||
|
||||
// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。
|
||||
// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX
|
||||
func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := map[string]any{"dept_id": deptID, "language": "zh_CN"}
|
||||
payload, _ := json.Marshal(body)
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // test stub fallback
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
Name string `json:"name"`
|
||||
ParentID int64 `json:"parent_id"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
return &DingTalkDeptInfo{
|
||||
DeptID: v.Result.DeptID,
|
||||
Name: v.Result.Name,
|
||||
ParentID: v.Result.ParentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) {
|
||||
appToken, err := c.GetAppToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := map[string]string{"userid": userID}
|
||||
payload, _ := json.Marshal(body)
|
||||
// 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX
|
||||
var targetURL string
|
||||
if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") {
|
||||
targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken)
|
||||
} else {
|
||||
targetURL = c.cfg.UserInfoURL // fallback for test stub
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
var v struct {
|
||||
Result struct {
|
||||
UserID string `json:"userid"`
|
||||
Name string `json:"name"`
|
||||
Nickname string `json:"nickname"`
|
||||
Email string `json:"email"`
|
||||
OrgEmail string `json:"org_email"`
|
||||
Extension string `json:"extension"`
|
||||
DeptID []int64 `json:"dept_id_list"`
|
||||
} `json:"result"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if v.ErrCode != 0 {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
if strings.TrimSpace(v.Result.UserID) == "" {
|
||||
return nil, parseDingTalkErr(raw, resp.StatusCode)
|
||||
}
|
||||
// 邮箱三级 fallback:org_email > email > extension["企业邮箱"](钉钉自定义扩展字段,JSON string)
|
||||
email := strings.TrimSpace(v.Result.OrgEmail)
|
||||
emailSource := "org_email"
|
||||
if email == "" {
|
||||
email = strings.TrimSpace(v.Result.Email)
|
||||
emailSource = "email"
|
||||
}
|
||||
extensionParsed := false
|
||||
if email == "" && strings.TrimSpace(v.Result.Extension) != "" {
|
||||
var ext map[string]string
|
||||
if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil {
|
||||
extensionParsed = true
|
||||
if v, ok := ext["企业邮箱"]; ok {
|
||||
email = strings.TrimSpace(v)
|
||||
emailSource = "extension.企业邮箱"
|
||||
}
|
||||
}
|
||||
}
|
||||
if email == "" {
|
||||
emailSource = "none"
|
||||
}
|
||||
slog.Info("dingtalk staff fetched",
|
||||
"userid", v.Result.UserID,
|
||||
"name_present", v.Result.Name != "",
|
||||
"nickname_present", v.Result.Nickname != "",
|
||||
"name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname,
|
||||
"email_present", v.Result.Email != "",
|
||||
"org_email_present", v.Result.OrgEmail != "",
|
||||
"extension_present", v.Result.Extension != "",
|
||||
"extension_parsed", extensionParsed,
|
||||
"email_source", emailSource,
|
||||
"dept_count", len(v.Result.DeptID),
|
||||
)
|
||||
return &DingTalkStaffInfo{
|
||||
UserID: v.Result.UserID,
|
||||
Name: v.Result.Name,
|
||||
Nickname: v.Result.Nickname,
|
||||
Email: email,
|
||||
DeptIDs: v.Result.DeptID,
|
||||
}, nil
|
||||
}
|
||||
143
backend/internal/handler/auth_dingtalk_client_test.go
Normal file
143
backend/internal/handler/auth_dingtalk_client_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "POST", r.Method)
|
||||
require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{
|
||||
ClientID: "k", ClientSecret: "s",
|
||||
TokenURL: server.URL + "/v1.0/oauth2/userAccessToken",
|
||||
},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "USER_TOKEN_X", resp.AccessToken)
|
||||
require.Equal(t, "dingABC", resp.CorpID)
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "UID_AAA", unionID)
|
||||
require.Equal(t, "张三", nick)
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetAppToken_Cached(t *testing.T) {
|
||||
callCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
t1, err := cli.GetAppToken(context.Background())
|
||||
require.NoError(t, err)
|
||||
t2, err := cli.GetAppToken(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, t1, t2)
|
||||
require.Equal(t, 1, callCount, "second call should hit cache")
|
||||
}
|
||||
|
||||
func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) {
|
||||
appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`))
|
||||
}))
|
||||
defer appTokenServer.Close()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId"
|
||||
|
||||
_, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA")
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*DingTalkAPIError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "60011", apiErr.Code)
|
||||
}
|
||||
|
||||
// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。
|
||||
func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{
|
||||
UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me,走 test stub 路径
|
||||
},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
info, err := cli.GetDeptInfo(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), info.DeptID)
|
||||
require.Equal(t, "AI数据", info.Name)
|
||||
require.Equal(t, int64(1), info.ParentID)
|
||||
}
|
||||
|
||||
// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003(部门不存在)时返回错误。
|
||||
func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "APP_TKN"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
_, err := cli.GetDeptInfo(context.Background(), 999)
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*DingTalkAPIError)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "60003", apiErr.Code)
|
||||
}
|
||||
1066
backend/internal/handler/auth_dingtalk_oauth.go
Normal file
1066
backend/internal/handler/auth_dingtalk_oauth.go
Normal file
File diff suppressed because it is too large
Load Diff
391
backend/internal/handler/auth_dingtalk_oauth_test.go
Normal file
391
backend/internal/handler/auth_dingtalk_oauth_test.go
Normal file
@ -0,0 +1,391 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDingTalkOAuthStart_Disabled は sentinel テスト。
|
||||
// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。
|
||||
func TestDingTalkOAuthStart_Disabled(t *testing.T) {
|
||||
t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。
|
||||
func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) {
|
||||
unionID := "union_AbCdEf123"
|
||||
email := buildDingTalkSyntheticEmail(unionID)
|
||||
|
||||
want := "dingtalk-union_abcdef123@dingtalk-connect.invalid"
|
||||
require.Equal(t, want, email)
|
||||
|
||||
// 确保结果都是小写(邮箱大小写不敏感,统一小写)
|
||||
require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase")
|
||||
|
||||
// 确保前缀正确
|
||||
require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix")
|
||||
|
||||
// 确保后缀是合成邮箱域名
|
||||
require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。
|
||||
func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) {
|
||||
email := buildDingTalkSyntheticEmail(" UID_XYZ ")
|
||||
require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email)
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct(跨组织降级路径)时:
|
||||
// - subject 等于 unionID(与 identityKey.ProviderSubject 一致)
|
||||
// - corp_user_id 为空字符串(跨组织时拿不到企业 userid)
|
||||
// - email/username 为空字符串
|
||||
// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{},claims 不应有 nil。
|
||||
func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X")
|
||||
|
||||
require.Equal(t, "", claims["email"])
|
||||
require.Equal(t, "", claims["username"])
|
||||
// 重构后 subject = unionID(与 identityKey.ProviderSubject 保持一致)
|
||||
require.Equal(t, "UNION_AAA", claims["subject"])
|
||||
require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空
|
||||
require.Equal(t, "UNION_AAA", claims["union_id"])
|
||||
require.Equal(t, "CORP_X", claims["corp_id"])
|
||||
}
|
||||
|
||||
// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。
|
||||
// D: corp 校验提前后逻辑不变。
|
||||
func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) {
|
||||
cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"}
|
||||
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp")
|
||||
}
|
||||
|
||||
// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。
|
||||
// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段,
|
||||
// 因此 checkDingTalkCorpAllowed 完全不校验 corpID,由 step 3 GetUserIdByUnionId 做真实判定
|
||||
// (跨企业用户会被钉钉错误码 60011/60121 拒绝,mapDingTalkErrorCode 映射回 corp_rejected)。
|
||||
func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) {
|
||||
cfgWithCorpID := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "dingInternal",
|
||||
}
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策,step 3 兜底")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId)")
|
||||
|
||||
cfgNoCorpID := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
InternalCorpID: "",
|
||||
}
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过")
|
||||
assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时
|
||||
// Step 3/4 失败应降级(shouldFallback=true, isFatal=false)。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) {
|
||||
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err)
|
||||
|
||||
require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback")
|
||||
require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) {
|
||||
stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr)
|
||||
|
||||
require.True(t, shouldFallback, "policy='': step failure should trigger fallback")
|
||||
require.False(t, isFatal, "policy='': step failure should NOT be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时
|
||||
// Step 3/4 失败应 hard fail(isFatal=true)。
|
||||
func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) {
|
||||
step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403}
|
||||
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err)
|
||||
|
||||
require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error")
|
||||
require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal")
|
||||
}
|
||||
|
||||
// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。
|
||||
func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) {
|
||||
for _, policy := range []string{"none", "internal_only", ""} {
|
||||
shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil)
|
||||
require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy)
|
||||
require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时
|
||||
// 退到 email local part(@ 之前的部分)。
|
||||
// E: CompleteDingTalkOAuthRegistration username fallback。
|
||||
func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
username string
|
||||
wantUser string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "username empty, normal email → local part",
|
||||
email: "dingtalk-uid123@dingtalk-connect.invalid",
|
||||
username: "",
|
||||
wantUser: "dingtalk-uid123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "username already set → keep original",
|
||||
email: "user@example.com",
|
||||
username: "张三",
|
||||
wantUser: "张三",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "username empty, no @ in email → use whole email",
|
||||
email: "noemail",
|
||||
username: "",
|
||||
wantUser: "noemail",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "both empty → invalid",
|
||||
email: "",
|
||||
username: "",
|
||||
wantUser: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
username := tc.username
|
||||
email := tc.email
|
||||
|
||||
// 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑
|
||||
if username == "" {
|
||||
if at := strings.Index(email, "@"); at > 0 {
|
||||
username = email[:at]
|
||||
} else {
|
||||
username = email
|
||||
}
|
||||
}
|
||||
|
||||
isValid := email != "" && username != ""
|
||||
require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email))
|
||||
require.Equal(t, tc.wantValid, isValid, "validity check")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID
|
||||
// 而非 staff.UserID,与 identityKey.ProviderSubject 保持一致。
|
||||
// §4.2: buildDingTalkUpstreamClaims subject 字段修正。
|
||||
func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789")
|
||||
|
||||
// 重构后 subject = unionID(全局唯一,与 identityKey.ProviderSubject 一致)
|
||||
require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor")
|
||||
// 企业 userid 保留为独立字段,供 audit/debug 使用
|
||||
require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID")
|
||||
// union_id 字段与 subject 相同(冗余保留,便于读取)
|
||||
require.Equal(t, "union456", claims["union_id"])
|
||||
require.Equal(t, "dingcorp789", claims["corp_id"])
|
||||
require.Equal(t, "张三", claims["username"])
|
||||
require.Equal(t, "zhangsan@corp.com", claims["email"])
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时
|
||||
// corp_user_id 为空字符串(跨组织拿不到企业 userid),subject 仍为 unionID。
|
||||
func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) {
|
||||
// 跨组织降级路径:staff = &DingTalkStaffInfo{}(所有字段为零值)
|
||||
staff := &DingTalkStaffInfo{}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp")
|
||||
|
||||
require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users")
|
||||
require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback")
|
||||
require.Equal(t, "", claims["email"])
|
||||
require.Equal(t, "", claims["username"])
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。
|
||||
func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX")
|
||||
|
||||
// 只取首个 dept_id
|
||||
require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id")
|
||||
}
|
||||
|
||||
// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。
|
||||
func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY")
|
||||
|
||||
require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts")
|
||||
}
|
||||
|
||||
// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。
|
||||
func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) {
|
||||
staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}}
|
||||
claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ")
|
||||
|
||||
recovered := dingTalkStaffFromClaims(claims)
|
||||
require.Equal(t, "王五", recovered.Name)
|
||||
require.Equal(t, "ww@corp.com", recovered.Email)
|
||||
require.Equal(t, "u3", recovered.UserID)
|
||||
require.Equal(t, []int64{55}, recovered.DeptIDs)
|
||||
}
|
||||
|
||||
// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门(parent_id=1)返回部门名。
|
||||
func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) {
|
||||
handler := &AuthHandler{}
|
||||
callCount := 0
|
||||
responses := map[string]string{
|
||||
"42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`,
|
||||
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
var req struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok {
|
||||
_, _ = w.Write([]byte(resp))
|
||||
} else {
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "tok"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "研发部", path)
|
||||
require.Equal(t, 2, callCount)
|
||||
}
|
||||
|
||||
// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key
|
||||
// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证
|
||||
// syncField 构建逻辑(即 attr key 从 cfg 读取)。
|
||||
// 间接验证:通过构造定制 cfg,确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic)。
|
||||
func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) {
|
||||
handler := &AuthHandler{
|
||||
userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic
|
||||
}
|
||||
|
||||
cfg := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
SyncCorpEmail: true,
|
||||
SyncDisplayName: true,
|
||||
SyncDept: true,
|
||||
// 自定义 attr key(非默认值)
|
||||
SyncCorpEmailAttrKey: "custom_email_key",
|
||||
SyncDisplayNameAttrKey: "custom_name_key",
|
||||
SyncDeptAttrKey: "custom_dept_key",
|
||||
}
|
||||
|
||||
staff := &DingTalkStaffInfo{
|
||||
Name: "张三",
|
||||
Email: "zhangsan@example.com",
|
||||
}
|
||||
|
||||
// 调用不应 panic(userAttributeService 为 nil 时走 warn 跳过路径)
|
||||
require.NotPanics(t, func() {
|
||||
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时
|
||||
// 使用 fallback 默认值(dingtalk_email / dingtalk_name / dingtalk_department)。
|
||||
// 此测试主要验证调用路径不 panic;实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。
|
||||
func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) {
|
||||
handler := &AuthHandler{
|
||||
userAttributeService: nil,
|
||||
}
|
||||
|
||||
cfg := config.DingTalkConnectConfig{
|
||||
CorpRestrictionPolicy: "internal_only",
|
||||
SyncCorpEmail: true,
|
||||
SyncDisplayName: true,
|
||||
SyncDept: false,
|
||||
// 不设置 attr key(等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充)
|
||||
SyncCorpEmailAttrKey: "dingtalk_email",
|
||||
SyncDisplayNameAttrKey: "dingtalk_name",
|
||||
SyncDeptAttrKey: "dingtalk_department",
|
||||
}
|
||||
|
||||
staff := &DingTalkStaffInfo{
|
||||
Name: "李四",
|
||||
Email: "lisi@corp.com",
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false)
|
||||
})
|
||||
}
|
||||
|
||||
// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。
|
||||
func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) {
|
||||
handler := &AuthHandler{}
|
||||
// 模拟:42(AI研发) → parent=10(研发部) → parent=1(根)
|
||||
responses := map[string]string{
|
||||
"42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`,
|
||||
"10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`,
|
||||
"1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`,
|
||||
}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析请求 body 拿到 dept_id
|
||||
var req struct {
|
||||
DeptID int64 `json:"dept_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
key := fmt.Sprintf("%d", req.DeptID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if resp, ok := responses[key]; ok {
|
||||
_, _ = w.Write([]byte(resp))
|
||||
} else {
|
||||
_, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cli := &DingTalkClient{
|
||||
cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"},
|
||||
httpClient: server.Client(),
|
||||
}
|
||||
cli.appToken = "tok"
|
||||
cli.appTokenExp = time.Now().Add(time.Hour)
|
||||
|
||||
path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "研发部/AI研发", path)
|
||||
}
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
@ -18,25 +19,30 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
redeemService *service.RedeemService
|
||||
totpService *service.TotpService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
redeemService *service.RedeemService
|
||||
totpService *service.TotpService
|
||||
userAttributeService *service.UserAttributeService
|
||||
|
||||
dingTalkClientInstance *DingTalkClient
|
||||
dingTalkClientMu sync.Mutex
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
redeemService: redeemService,
|
||||
totpService: totpService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
redeemService: redeemService,
|
||||
totpService: totpService,
|
||||
userAttributeService: userAttributeService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri
|
||||
if email == "" ||
|
||||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("pending auth session create failed",
|
||||
"intent", strings.TrimSpace(payload.Intent),
|
||||
"provider_type", strings.TrimSpace(payload.Identity.ProviderType),
|
||||
"provider_key", strings.TrimSpace(payload.Identity.ProviderKey),
|
||||
"provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)),
|
||||
"resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)),
|
||||
"has_target_user", payload.TargetUserID != nil,
|
||||
"error", err.Error())
|
||||
return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err)
|
||||
}
|
||||
|
||||
@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required")
|
||||
}
|
||||
|
||||
// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。
|
||||
// 钉钉跨组织/staff 邮箱缺失时进入此状态:前端跳到补邮箱页,exchange 不应走 adoption apply。
|
||||
func pendingSessionRequiresEmailCompletion(payload map[string]any) bool {
|
||||
if v, ok := payload["requires_email_completion"].(bool); ok && v {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion")
|
||||
}
|
||||
|
||||
// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。
|
||||
// 钉钉 signupBlocked=true(注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单,
|
||||
// exchange 不应消费 session,否则后续 /pending/bind-login 找不到 session。
|
||||
func pendingSessionRequiresBindLogin(payload map[string]any) bool {
|
||||
return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required")
|
||||
}
|
||||
|
||||
func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool {
|
||||
if session == nil {
|
||||
return false
|
||||
@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string]
|
||||
delete(normalized, key)
|
||||
}
|
||||
step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step")))
|
||||
// 把多种 choice 别名归一为 oauthPendingChoiceStep;bind_login_required 是独立终态
|
||||
// (前端渲染 needsBindLogin 而非 needsChooser),故不能并入归一化列表。
|
||||
switch step {
|
||||
case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required":
|
||||
case "choice", "choose_account_action", "choose_account", "choose", "email_required":
|
||||
normalized["step"] = oauthPendingChoiceStep
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) {
|
||||
@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) {
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
// bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username(用户已有自己的名字)
|
||||
h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID)
|
||||
tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "")
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token pair")
|
||||
@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string)
|
||||
}
|
||||
|
||||
h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID)
|
||||
// createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值
|
||||
h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID)
|
||||
clearCookies()
|
||||
writeOAuthTokenPairResponse(c, tokenPair)
|
||||
}
|
||||
@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if pendingSessionRequiresEmailCompletion(payload) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if pendingSessionRequiresBindLogin(payload) {
|
||||
response.Success(c, payload)
|
||||
return
|
||||
}
|
||||
if !adoptionDecision.hasDecision() {
|
||||
adoptionRequired, _ := payload["adoption_required"].(bool)
|
||||
if adoptionRequired {
|
||||
|
||||
@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string)
|
||||
if email == "" ||
|
||||
strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) {
|
||||
strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) ||
|
||||
strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode)
|
||||
tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat")
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
67
backend/internal/handler/dto/account_mapper_redact_test.go
Normal file
67
backend/internal/handler/dto/account_mapper_redact_test.go
Normal file
@ -0,0 +1,67 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) {
|
||||
src := &service.Account{
|
||||
ID: 42,
|
||||
Name: "demo",
|
||||
Platform: "anthropic",
|
||||
Type: "oauth",
|
||||
Credentials: map[string]any{
|
||||
"access_token": "at-secret",
|
||||
"refresh_token": "rt-secret",
|
||||
"id_token": "id-secret",
|
||||
"api_key": "sk-secret",
|
||||
"base_url": "https://api.example.com",
|
||||
"model_mapping": map[string]any{"foo": "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
got := AccountFromServiceShallow(src)
|
||||
require.NotNil(t, got)
|
||||
|
||||
// 敏感键不在 Credentials 里
|
||||
require.NotContains(t, got.Credentials, "access_token")
|
||||
require.NotContains(t, got.Credentials, "refresh_token")
|
||||
require.NotContains(t, got.Credentials, "id_token")
|
||||
require.NotContains(t, got.Credentials, "api_key")
|
||||
// 非敏感键保留
|
||||
require.Equal(t, "https://api.example.com", got.Credentials["base_url"])
|
||||
require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"])
|
||||
|
||||
// 状态 map 标记敏感键存在
|
||||
require.True(t, got.CredentialsStatus["has_access_token"])
|
||||
require.True(t, got.CredentialsStatus["has_refresh_token"])
|
||||
require.True(t, got.CredentialsStatus["has_id_token"])
|
||||
require.True(t, got.CredentialsStatus["has_api_key"])
|
||||
|
||||
// JSON 序列化校验:响应体里不会出现敏感子串
|
||||
raw, err := json.Marshal(got)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(raw), "rt-secret")
|
||||
require.NotContains(t, string(raw), "at-secret")
|
||||
require.NotContains(t, string(raw), "sk-secret")
|
||||
require.NotContains(t, string(raw), "id-secret")
|
||||
// 状态标识应序列化进 JSON
|
||||
require.Contains(t, string(raw), "credentials_status")
|
||||
require.Contains(t, string(raw), "has_refresh_token")
|
||||
|
||||
// 原始 service.Account 不应被改动
|
||||
require.Equal(t, "rt-secret", src.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) {
|
||||
src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"}
|
||||
got := AccountFromServiceShallow(src)
|
||||
require.NotNil(t, got)
|
||||
require.Nil(t, got.Credentials)
|
||||
require.Nil(t, got.CredentialsStatus)
|
||||
}
|
||||
44
backend/internal/handler/dto/credentials_redact.go
Normal file
44
backend/internal/handler/dto/credentials_redact.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Package dto provides data transfer objects for HTTP handlers.
|
||||
package dto
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
// RedactCredentials 复制一份 in,剥离 service.SensitiveCredentialKeys 列出的所有敏感子键,
|
||||
// 并产出一个 has_<key> 状态 map 表示哪些敏感键存在且非零值。
|
||||
//
|
||||
// 输入 nil 时返回 nil, nil(避免响应里出现空对象)。
|
||||
// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。
|
||||
func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) {
|
||||
if in == nil {
|
||||
return nil, nil
|
||||
}
|
||||
out = make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
if service.IsSensitiveCredentialKey(k) {
|
||||
if isCredentialValuePresent(v) {
|
||||
if status == nil {
|
||||
status = make(map[string]bool, 4)
|
||||
}
|
||||
status["has_"+k] = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
out[k] = v
|
||||
}
|
||||
return out, status
|
||||
}
|
||||
|
||||
// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置;
|
||||
// 其余非零类型(数字、对象、字符串等)视为已配置。
|
||||
func isCredentialValuePresent(v any) bool {
|
||||
switch x := v.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case string:
|
||||
return x != ""
|
||||
case bool:
|
||||
return x
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
97
backend/internal/handler/dto/credentials_redact_test.go
Normal file
97
backend/internal/handler/dto/credentials_redact_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedactCredentials_NilInput(t *testing.T) {
|
||||
out, status := RedactCredentials(nil)
|
||||
require.Nil(t, out)
|
||||
require.Nil(t, status)
|
||||
}
|
||||
|
||||
func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "rt-secret",
|
||||
"access_token": "at-secret",
|
||||
"api_key": "sk-secret",
|
||||
"aws_secret_access_key": "aws-secret",
|
||||
"service_account_json": map[string]any{"private_key": "..."},
|
||||
"private_key": "raw-key",
|
||||
// 非敏感
|
||||
"base_url": "https://api.example.com",
|
||||
"model_mapping": map[string]any{"foo": "bar"},
|
||||
"project_id": "proj-1",
|
||||
"expires_at": int64(123456),
|
||||
}
|
||||
|
||||
out, status := RedactCredentials(in)
|
||||
|
||||
require.NotContains(t, out, "refresh_token")
|
||||
require.NotContains(t, out, "access_token")
|
||||
require.NotContains(t, out, "api_key")
|
||||
require.NotContains(t, out, "aws_secret_access_key")
|
||||
require.NotContains(t, out, "service_account_json")
|
||||
require.NotContains(t, out, "private_key")
|
||||
|
||||
require.Equal(t, "https://api.example.com", out["base_url"])
|
||||
require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"])
|
||||
require.Equal(t, "proj-1", out["project_id"])
|
||||
require.Equal(t, int64(123456), out["expires_at"])
|
||||
|
||||
require.True(t, status["has_refresh_token"])
|
||||
require.True(t, status["has_access_token"])
|
||||
require.True(t, status["has_api_key"])
|
||||
require.True(t, status["has_aws_secret_access_key"])
|
||||
require.True(t, status["has_service_account_json"])
|
||||
require.True(t, status["has_private_key"])
|
||||
|
||||
// 状态 map 不应携带非敏感键的 has_*
|
||||
require.NotContains(t, status, "has_base_url")
|
||||
require.NotContains(t, status, "has_project_id")
|
||||
}
|
||||
|
||||
func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "",
|
||||
"access_token": nil,
|
||||
"api_key": false,
|
||||
"id_token": "actual-id",
|
||||
}
|
||||
out, status := RedactCredentials(in)
|
||||
require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output")
|
||||
require.False(t, status["has_refresh_token"])
|
||||
require.False(t, status["has_access_token"])
|
||||
require.False(t, status["has_api_key"])
|
||||
require.True(t, status["has_id_token"])
|
||||
}
|
||||
|
||||
func TestRedactCredentials_DoesNotMutateInput(t *testing.T) {
|
||||
in := map[string]any{
|
||||
"refresh_token": "secret",
|
||||
"base_url": "x",
|
||||
}
|
||||
_, _ = RedactCredentials(in)
|
||||
require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改")
|
||||
require.Equal(t, "x", in["base_url"])
|
||||
}
|
||||
|
||||
func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) {
|
||||
keys := []string{
|
||||
"access_token", "refresh_token", "id_token",
|
||||
"api_key", "session_key", "cookie",
|
||||
"aws_secret_access_key", "aws_session_token",
|
||||
"service_account_json", "service_account", "private_key",
|
||||
}
|
||||
in := make(map[string]any, len(keys))
|
||||
for _, k := range keys {
|
||||
in[k] = "filled"
|
||||
}
|
||||
out, status := RedactCredentials(in)
|
||||
require.Empty(t, out)
|
||||
for _, k := range keys {
|
||||
require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k)
|
||||
}
|
||||
}
|
||||
@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
redactedCreds, credsStatus := RedactCredentials(a.Credentials)
|
||||
out := &Account{
|
||||
ID: a.ID,
|
||||
Name: a.Name,
|
||||
Notes: a.Notes,
|
||||
Platform: a.Platform,
|
||||
Type: a.Type,
|
||||
Credentials: a.Credentials,
|
||||
Credentials: redactedCreds,
|
||||
CredentialsStatus: credsStatus,
|
||||
Extra: a.Extra,
|
||||
ProxyID: a.ProxyID,
|
||||
Concurrency: a.Concurrency,
|
||||
@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
|
||||
UsedBy: rc.UsedBy,
|
||||
UsedAt: rc.UsedAt,
|
||||
CreatedAt: rc.CreatedAt,
|
||||
ExpiresAt: rc.ExpiresAt,
|
||||
GroupID: rc.GroupID,
|
||||
ValidityDays: rc.ValidityDays,
|
||||
User: UserFromServiceShallow(rc.User),
|
||||
Group: GroupFromServiceShallow(rc.Group),
|
||||
}
|
||||
if rc.IsExpired() {
|
||||
out.Status = service.StatusExpired
|
||||
}
|
||||
|
||||
// For admin_balance/admin_concurrency types, include notes so users can see
|
||||
// why they were charged or credited by admin
|
||||
@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
|
||||
FirstTokenMs: l.FirstTokenMs,
|
||||
ImageCount: l.ImageCount,
|
||||
ImageSize: l.ImageSize,
|
||||
ImageInputSize: l.ImageInputSize,
|
||||
ImageOutputSize: l.ImageOutputSize,
|
||||
ImageSizeSource: l.ImageSizeSource,
|
||||
ImageSizeBreakdown: l.ImageSizeBreakdown,
|
||||
MediaType: l.MediaType,
|
||||
UserAgent: l.UserAgent,
|
||||
CacheTTLOverridden: l.CacheTTLOverridden,
|
||||
|
||||
@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t *
|
||||
require.Equal(t, "claude-3", adminDTO.Model)
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
imageSize := "4K"
|
||||
inputSize := "1024x1024"
|
||||
outputSize := "3840x2160"
|
||||
source := "output"
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_image_metadata",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 2,
|
||||
ImageSize: &imageSize,
|
||||
ImageInputSize: &inputSize,
|
||||
ImageOutputSize: &outputSize,
|
||||
ImageSizeSource: &source,
|
||||
ImageSizeBreakdown: map[string]int{"4K": 2},
|
||||
}
|
||||
|
||||
userDTO := UsageLogFromService(log)
|
||||
adminDTO := UsageLogFromServiceAdmin(log)
|
||||
|
||||
for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} {
|
||||
require.Equal(t, 2, got.ImageCount)
|
||||
require.NotNil(t, got.ImageSize)
|
||||
require.Equal(t, imageSize, *got.ImageSize)
|
||||
require.NotNil(t, got.ImageInputSize)
|
||||
require.Equal(t, inputSize, *got.ImageInputSize)
|
||||
require.NotNil(t, got.ImageOutputSize)
|
||||
require.Equal(t, outputSize, *got.ImageOutputSize)
|
||||
require.NotNil(t, got.ImageSizeSource)
|
||||
require.Equal(t, source, *got.ImageSizeSource)
|
||||
require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &service.UsageLog{
|
||||
RequestID: "req_legacy_image_missing_size",
|
||||
Model: "gpt-image-2",
|
||||
ImageCount: 1,
|
||||
ImageSize: nil,
|
||||
}
|
||||
|
||||
dto := UsageLogFromService(log)
|
||||
require.Equal(t, 1, dto.ImageCount)
|
||||
require.Nil(t, dto.ImageSize)
|
||||
require.Nil(t, dto.ImageInputSize)
|
||||
require.Nil(t, dto.ImageOutputSize)
|
||||
require.Nil(t, dto.ImageSizeSource)
|
||||
require.Nil(t, dto.ImageSizeBreakdown)
|
||||
|
||||
body, err := json.Marshal(dto)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(body), `"image_size":null`)
|
||||
require.NotContains(t, string(body), `"image_size":"2K"`)
|
||||
}
|
||||
|
||||
func f64Ptr(value float64) *float64 {
|
||||
return &value
|
||||
}
|
||||
|
||||
@ -56,6 +56,23 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"`
|
||||
DingTalkConnectClientID string `json:"dingtalk_connect_client_id"`
|
||||
DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"`
|
||||
DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"`
|
||||
DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"`
|
||||
DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"`
|
||||
DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"`
|
||||
DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"`
|
||||
DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"`
|
||||
DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"`
|
||||
DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"`
|
||||
DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"`
|
||||
DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"`
|
||||
DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"`
|
||||
DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"`
|
||||
DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"`
|
||||
|
||||
WeChatConnectEnabled bool `json:"wechat_connect_enabled"`
|
||||
WeChatConnectAppID string `json:"wechat_connect_app_id"`
|
||||
WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"`
|
||||
@ -201,6 +218,9 @@ type SystemSettings struct {
|
||||
PaymentCancelRateLimitUnit string `json:"payment_cancel_rate_limit_unit"`
|
||||
PaymentCancelRateLimitMode string `json:"payment_cancel_rate_limit_window_mode"`
|
||||
|
||||
// Force Alipay mobile clients to use QR code payment instead of mobile redirect
|
||||
PaymentAlipayForceQRCode bool `json:"payment_alipay_force_qrcode"`
|
||||
|
||||
// Balance low notification
|
||||
BalanceLowNotifyEnabled bool `json:"balance_low_notify_enabled"`
|
||||
BalanceLowNotifyThreshold float64 `json:"balance_low_notify_threshold"`
|
||||
@ -260,6 +280,7 @@ type PublicSettings struct {
|
||||
TablePageSizeOptions []int `json:"table_page_size_options"`
|
||||
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
|
||||
CustomEndpoints []CustomEndpoint `json:"custom_endpoints"`
|
||||
DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"`
|
||||
WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"`
|
||||
|
||||
@ -149,25 +149,28 @@ type AdminGroup struct {
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
// Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥
|
||||
// 的存在性通过 CredentialsStatus(has_<key>)暴露,原始值不返回前端。
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
CredentialsStatus map[string]bool `json:"credentials_status,omitempty"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
LoadFactor *int `json:"load_factor,omitempty"`
|
||||
Priority int `json:"priority"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
AutoPauseOnExpired bool `json:"auto_pause_on_expired"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
Schedulable bool `json:"schedulable"`
|
||||
|
||||
@ -335,6 +338,7 @@ type RedeemCode struct {
|
||||
UsedBy *int64 `json:"used_by"`
|
||||
UsedAt *time.Time `json:"used_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
|
||||
GroupID *int64 `json:"group_id"`
|
||||
ValidityDays int `json:"validity_days"`
|
||||
@ -400,9 +404,13 @@ type UsageLog struct {
|
||||
FirstTokenMs *int `json:"first_token_ms"`
|
||||
|
||||
// 图片生成字段
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
MediaType *string `json:"media_type"`
|
||||
ImageCount int `json:"image_count"`
|
||||
ImageSize *string `json:"image_size"`
|
||||
ImageInputSize *string `json:"image_input_size"`
|
||||
ImageOutputSize *string `json:"image_output_size"`
|
||||
ImageSizeSource *string `json:"image_size_source"`
|
||||
ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
|
||||
MediaType *string `json:"media_type"`
|
||||
|
||||
// User-Agent
|
||||
UserAgent *string `json:"user_agent"`
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
@ -157,7 +158,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
@ -209,7 +210,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 验证 model 必填
|
||||
@ -351,6 +352,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", apiKey.GroupID),
|
||||
@ -400,6 +402,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
@ -621,6 +624,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID)
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
reqLog.Warn("gateway.select_account_no_available",
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64p("group_id", currentAPIKey.GroupID),
|
||||
@ -681,6 +685,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
reqLog.Warn("gateway.select_account_no_slot_no_wait_plan",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("model", reqModel),
|
||||
@ -1041,8 +1046,8 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
platform = forcedPlatform
|
||||
}
|
||||
|
||||
// Get available models from account configurations (without platform filter)
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
|
||||
// Get available models from account configurations for the selected group platform.
|
||||
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform)
|
||||
|
||||
if len(availableModels) > 0 {
|
||||
// Build model list from whitelist
|
||||
@ -1063,7 +1068,7 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Fallback to default models
|
||||
if platform == "openai" {
|
||||
if platform == service.PlatformOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": openai.DefaultModels,
|
||||
@ -1071,6 +1076,14 @@ func (h *GatewayHandler) Models(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": geminicli.DefaultModels,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": claude.DefaultModels,
|
||||
@ -1407,6 +1420,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT
|
||||
func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
@ -1611,7 +1629,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
|
||||
if err != nil {
|
||||
@ -1630,7 +1648,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body)
|
||||
setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsedReq.Stream, false)))
|
||||
|
||||
// 获取订阅信息(可能为nil)
|
||||
@ -1659,6 +1677,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err))
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
@ -78,7 +78,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
APIKeyID: apiKey.ID,
|
||||
}
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
groupPlatform := ""
|
||||
if apiKey.Group != nil {
|
||||
groupPlatform = apiKey.Group.Platform
|
||||
}
|
||||
selectionSessionHash := sessionHash
|
||||
if groupPlatform == service.PlatformGemini && selectionSessionHash != "" {
|
||||
selectionSessionHash = "gemini:" + selectionSessionHash
|
||||
}
|
||||
|
||||
// 3. Account selection + failover loop
|
||||
fs := NewFailoverState(h.maxAccountSwitches, false)
|
||||
if groupPlatform == service.PlatformGemini {
|
||||
fs = NewFailoverState(h.maxAccountSwitchesGemini, false)
|
||||
}
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini {
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
fs.FailedAccountIDs[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
// 5. Forward request
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
forwardBody := body
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
var result *service.ForwardResult
|
||||
if account.Platform == service.PlatformGemini {
|
||||
if h.geminiCompatService == nil {
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured")
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
return
|
||||
}
|
||||
result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody)
|
||||
} else {
|
||||
result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||
}
|
||||
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
|
||||
return
|
||||
}
|
||||
h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
|
||||
// Validate JSON
|
||||
if !gjson.ValidBytes(body) {
|
||||
@ -78,7 +78,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
// 解析渠道级模型映射
|
||||
@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0))
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
||||
return
|
||||
}
|
||||
@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr
|
||||
if lastErr != nil && lastErr.StatusCode > 0 {
|
||||
statusCode = lastErr.StatusCode
|
||||
}
|
||||
if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage())
|
||||
return
|
||||
}
|
||||
h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted")
|
||||
}
|
||||
|
||||
136
backend/internal/handler/gateway_models_test.go
Normal file
136
backend/internal/handler/gateway_models_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type gatewayModelsAccountRepoStub struct {
|
||||
service.AccountRepository
|
||||
|
||||
byGroup map[int64][]service.Account
|
||||
}
|
||||
|
||||
type gatewayModelsResponseForTest struct {
|
||||
Object string `json:"object"`
|
||||
Data []gatewayModelItemForTest `json:"data"`
|
||||
}
|
||||
|
||||
type gatewayModelItemForTest struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) {
|
||||
accounts, ok := s.byGroup[groupID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
out := make([]service.Account, len(accounts))
|
||||
copy(out, accounts)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler {
|
||||
return &GatewayHandler{
|
||||
gatewayService: service.NewGatewayService(
|
||||
repo,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(20)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{ID: 1, Platform: service.PlatformGemini},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, "list", got.Object)
|
||||
require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash")
|
||||
require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6")
|
||||
}
|
||||
|
||||
func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
groupID := int64(21)
|
||||
h := newGatewayModelsHandlerForTest(
|
||||
&gatewayModelsAccountRepoStub{
|
||||
byGroup: map[int64][]service.Account{
|
||||
groupID: {
|
||||
{
|
||||
ID: 1,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Platform: service.PlatformGemini,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-2.5-flash": "gemini-2.5-flash",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
|
||||
Group: &service.Group{ID: groupID, Platform: service.PlatformGemini},
|
||||
})
|
||||
|
||||
h.Models(c)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var got gatewayModelsResponseForTest
|
||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got))
|
||||
require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data))
|
||||
}
|
||||
|
||||
func modelIDsForTest(models []gatewayModelItemForTest) []string {
|
||||
ids := make([]string, 0, len(models))
|
||||
for _, model := range models {
|
||||
ids = append(ids, model.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModelsList())
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gemini.FallbackModel(modelName))
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -182,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, modelName, stream, body)
|
||||
setOpsRequestContext(c, modelName, stream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, authSubject, service.ContentModerationProtocolGemini, modelName, body); decision != nil && decision.Blocked {
|
||||
@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(fs.FailedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
|
||||
return
|
||||
}
|
||||
@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts")
|
||||
return
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIChat, reqModel, body); decision != nil && decision.Blocked {
|
||||
@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
} else {
|
||||
@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "")
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
|
||||
// resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for
|
||||
// OpenAI Chat Completions requests. For APIKey accounts whose upstream
|
||||
// has been probed to not support the Responses API, the request is
|
||||
// is forced or probed to not support the Responses API, the request is
|
||||
// forwarded directly to /v1/chat/completions — not through the default
|
||||
// CC→Responses conversion path.
|
||||
func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string {
|
||||
|
||||
@ -130,7 +130,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
sessionHashBody := body
|
||||
if service.IsOpenAIResponsesCompactPathForTest(c) {
|
||||
if compactSeed := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()); compactSeed != "" {
|
||||
@ -189,7 +189,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, body); decision != nil && decision.Blocked {
|
||||
@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
if errors.Is(err, service.ErrNoAvailableCompactAccounts) {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted)
|
||||
return
|
||||
@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if channelMapping.Mapped {
|
||||
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||
}
|
||||
writerSizeBeforeForward := c.Writer.Size()
|
||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody)
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
} else {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
if c.Writer.Size() != writerSizeBeforeForward {
|
||||
h.handleFailoverExhausted(c, failoverErr, true)
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// 池模式:同账号重试
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
@ -604,7 +611,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
setOpsRequestContext(c, reqModel, reqStream)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolAnthropicMessages, reqModel, body); decision != nil && decision.Blocked {
|
||||
@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
if err != nil {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||
}
|
||||
if selection.WaitPlan == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
@ -1163,7 +1174,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
zap.Bool("has_previous_response_id", previousResponseID != ""),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
setOpsRequestContext(c, reqModel, true)
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2))
|
||||
|
||||
if decision := h.checkContentModeration(c, reqLog, apiKey, subject, service.ContentModerationProtocolOpenAIResponses, reqModel, firstMessage); decision != nil && decision.Blocked {
|
||||
@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error,
|
||||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||||
statusCode := failoverErr.StatusCode
|
||||
responseBody := failoverErr.ResponseBody
|
||||
if service.IsOpenAISilentRefusalErrorBody(responseBody) {
|
||||
service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "")
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先检查透传规则
|
||||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||||
|
||||
@ -63,9 +63,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
}
|
||||
|
||||
if isMultipartImagesContentType(c.GetHeader("Content-Type")) {
|
||||
setOpsRequestContext(c, "", false, nil)
|
||||
setOpsRequestContext(c, "", false)
|
||||
} else {
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
setOpsRequestContext(c, "", false)
|
||||
}
|
||||
|
||||
parsed, err := h.gatewayService.ParseOpenAIImagesRequest(c, body)
|
||||
@ -98,9 +98,9 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
}
|
||||
|
||||
if parsed.Multipart {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, nil)
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream)
|
||||
} else {
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream, body)
|
||||
setOpsRequestContext(c, parsed.Model, parsed.Stream)
|
||||
}
|
||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(parsed.Stream, false)))
|
||||
|
||||
@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
markOpsRoutingCapacityLimitedIfNoAvailable(c, err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
@ -22,10 +23,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
opsModelKey = "ops_model"
|
||||
opsStreamKey = "ops_stream"
|
||||
opsRequestBodyKey = "ops_request_body"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
opsModelKey = "ops_model"
|
||||
opsStreamKey = "ops_stream"
|
||||
opsAccountIDKey = "ops_account_id"
|
||||
opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited"
|
||||
|
||||
opsUpstreamModelKey = "ops_upstream_model"
|
||||
opsRequestTypeKey = "ops_request_type"
|
||||
@ -45,6 +46,8 @@ const (
|
||||
opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND"
|
||||
opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID"
|
||||
opsCodeUserInactive = "USER_INACTIVE"
|
||||
opsCodeInvalidAPIKey = "INVALID_API_KEY"
|
||||
opsCodeAPIKeyRequired = "API_KEY_REQUIRED"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -332,16 +335,13 @@ func opsErrorLogConfig() (workerCount int, queueSize int) {
|
||||
return workerCount, queueSize
|
||||
}
|
||||
|
||||
func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody []byte) {
|
||||
func setOpsRequestContext(c *gin.Context, model string, stream bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
c.Set(opsModelKey, model)
|
||||
c.Set(opsStreamKey, stream)
|
||||
if len(requestBody) > 0 {
|
||||
c.Set(opsRequestBodyKey, requestBody)
|
||||
}
|
||||
if c.Request != nil && model != "" {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
@ -360,22 +360,6 @@ func setOpsEndpointContext(c *gin.Context, upstreamModel string, requestType int
|
||||
c.Set(opsRequestTypeKey, requestType)
|
||||
}
|
||||
|
||||
func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
v, ok := c.Get(opsRequestBodyKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
raw, ok := v.([]byte)
|
||||
if !ok || len(raw) == 0 {
|
||||
return
|
||||
}
|
||||
entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw)
|
||||
opsErrorLogSanitized.Add(1)
|
||||
}
|
||||
|
||||
func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) {
|
||||
if c == nil || accountID <= 0 {
|
||||
return
|
||||
@ -393,6 +377,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string)
|
||||
}
|
||||
}
|
||||
|
||||
func markOpsRoutingCapacityLimited(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.Set(opsRoutingCapacityLimitedKey, true)
|
||||
}
|
||||
|
||||
func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) {
|
||||
if !isOpsNoAvailableAccountError(err) {
|
||||
return
|
||||
}
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
}
|
||||
|
||||
func isOpsRoutingCapacityLimited(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
v, ok := c.Get(opsRoutingCapacityLimitedKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
marked, _ := v.(bool)
|
||||
return marked
|
||||
}
|
||||
|
||||
func isOpsNoAvailableAccountError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) {
|
||||
return true
|
||||
}
|
||||
return isOpsNoAvailableAccountMessage(err.Error())
|
||||
}
|
||||
|
||||
type opsCaptureWriter struct {
|
||||
gin.ResponseWriter
|
||||
limit int
|
||||
@ -671,7 +691,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
// Severity/retryability should reflect the upstream failure, not the final client status (200).
|
||||
// Severity should reflect the upstream failure, not the final client status (200).
|
||||
Severity: classifyOpsSeverity("upstream_error", effectiveUpstreamStatus),
|
||||
StatusCode: status,
|
||||
IsBusinessLimited: false,
|
||||
@ -688,9 +708,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
UpstreamErrorDetail: upstreamErrorDetail,
|
||||
UpstreamErrors: events,
|
||||
|
||||
IsRetryable: classifyOpsIsRetryable("upstream_error", effectiveUpstreamStatus),
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
@ -714,10 +732,6 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
// Store request headers/body only when an upstream error occurred to keep overhead minimal.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
// Skip logging if a passthrough rule with skip_monitoring=true matched.
|
||||
if v, ok := c.Get(service.OpsSkipPassthroughKey); ok {
|
||||
if skip, _ := v.(bool); skip {
|
||||
@ -775,11 +789,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
|
||||
normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
|
||||
|
||||
phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
|
||||
isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
|
||||
|
||||
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
|
||||
errorSource := classifyOpsErrorSource(phase, parsed.Message)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{
|
||||
RequestID: requestID,
|
||||
@ -834,9 +844,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
ErrorSource: errorSource,
|
||||
ErrorOwner: errorOwner,
|
||||
|
||||
IsRetryable: classifyOpsIsRetryable(normalizedType, status),
|
||||
RetryCount: 0,
|
||||
CreatedAt: time.Now(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
applyOpsLatencyFieldsFromContext(c, entry)
|
||||
|
||||
@ -914,20 +922,10 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
||||
entry.ClientIP = &clientIP
|
||||
}
|
||||
|
||||
// Persist only a minimal, whitelisted set of request headers to improve retry fidelity.
|
||||
// Do NOT store Authorization/Cookie/etc.
|
||||
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c)
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
enqueueOpsErrorLog(ops, entry)
|
||||
}
|
||||
}
|
||||
|
||||
var opsRetryRequestHeaderAllowlist = []string{
|
||||
"anthropic-beta",
|
||||
"anthropic-version",
|
||||
}
|
||||
|
||||
// isCountTokensRequest checks if the request is a count_tokens request
|
||||
func isCountTokensRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil || c.Request.URL == nil {
|
||||
@ -936,32 +934,6 @@ func isCountTokensRequest(c *gin.Context) bool {
|
||||
return strings.Contains(c.Request.URL.Path, "/count_tokens")
|
||||
}
|
||||
|
||||
func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
||||
if c == nil || c.Request == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
headers := make(map[string]string, 4)
|
||||
for _, key := range opsRetryRequestHeaderAllowlist {
|
||||
v := strings.TrimSpace(c.GetHeader(key))
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
// Keep headers small even if a client sends something unexpected.
|
||||
headers[key] = truncateString(v, 512)
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
raw, err := json.Marshal(headers)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
s := string(raw)
|
||||
return &s
|
||||
}
|
||||
|
||||
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||
if c == nil || entry == nil {
|
||||
return
|
||||
@ -1116,6 +1088,9 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
msg := strings.ToLower(message)
|
||||
// Standardized phases: request|auth|routing|upstream|network|internal
|
||||
// Map billing/concurrency/response => request; scheduling => routing.
|
||||
if isOpsClientAuthError(code, msg) {
|
||||
return "auth"
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid:
|
||||
return "request"
|
||||
@ -1136,7 +1111,7 @@ func classifyOpsPhase(errType, message, code string) string {
|
||||
case "upstream_error", "overloaded_error":
|
||||
return "upstream"
|
||||
case "api_error":
|
||||
if strings.Contains(msg, opsErrNoAvailableAccounts) {
|
||||
if isOpsNoAvailableAccountMessage(msg) {
|
||||
return "routing"
|
||||
}
|
||||
return "internal"
|
||||
@ -1162,25 +1137,31 @@ func classifyOpsSeverity(errType string, status int) string {
|
||||
return "P3"
|
||||
}
|
||||
|
||||
func classifyOpsIsRetryable(errType string, statusCode int) bool {
|
||||
switch errType {
|
||||
case "authentication_error", "invalid_request_error":
|
||||
return false
|
||||
case "timeout_error":
|
||||
return true
|
||||
case "rate_limit_error":
|
||||
// May be transient (upstream or queue); retry can help.
|
||||
return true
|
||||
case "billing_error", "subscription_error":
|
||||
return false
|
||||
case "upstream_error", "overloaded_error":
|
||||
return statusCode >= 500 || statusCode == 429 || statusCode == 529
|
||||
default:
|
||||
return statusCode >= 500
|
||||
func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) {
|
||||
phase = classifyOpsPhase(errType, message, code)
|
||||
routingCapacityLimited := isOpsRoutingCapacityLimited(c)
|
||||
clientBusinessLimited := service.HasOpsClientBusinessLimited(c)
|
||||
upstreamError := hasOpsUpstreamErrorContext(c)
|
||||
if upstreamError && !routingCapacityLimited {
|
||||
phase = "upstream"
|
||||
}
|
||||
if clientBusinessLimited && !upstreamError && !routingCapacityLimited {
|
||||
phase = "auth"
|
||||
}
|
||||
if routingCapacityLimited {
|
||||
phase = "routing"
|
||||
}
|
||||
localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message))
|
||||
isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError)
|
||||
errorOwner = classifyOpsErrorOwner(phase, message)
|
||||
errorSource = classifyOpsErrorSource(phase, message)
|
||||
return phase, isBusinessLimited, errorOwner, errorSource
|
||||
}
|
||||
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool {
|
||||
func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool {
|
||||
if len(localClientAuthError) > 0 && localClientAuthError[0] {
|
||||
return true
|
||||
}
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive:
|
||||
return true
|
||||
@ -1197,6 +1178,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpsClientAuthError(code string, msg string) bool {
|
||||
switch strings.TrimSpace(code) {
|
||||
case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired:
|
||||
return true
|
||||
}
|
||||
return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required")
|
||||
}
|
||||
|
||||
func hasOpsUpstreamErrorContext(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok {
|
||||
switch code := v.(type) {
|
||||
case int:
|
||||
if code > 0 {
|
||||
return true
|
||||
}
|
||||
case int64:
|
||||
if code > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok {
|
||||
if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isOpsNoAvailableAccountMessage(message string) bool {
|
||||
msg := strings.ToLower(message)
|
||||
return strings.Contains(msg, opsErrNoAvailableAccounts) ||
|
||||
strings.Contains(msg, "no available account") ||
|
||||
strings.Contains(msg, "no available gemini accounts") ||
|
||||
strings.Contains(msg, "no available openai accounts") ||
|
||||
strings.Contains(msg, "no available compatible accounts")
|
||||
}
|
||||
|
||||
func classifyOpsErrorOwner(phase string, message string) string {
|
||||
// Standardized owners: client|provider|platform
|
||||
switch phase {
|
||||
|
||||
@ -44,49 +44,6 @@ func resetOpsErrorLoggerStateForTest(t *testing.T) {
|
||||
opsErrorLogDrained.Store(false)
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`)
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.NotNil(t, entry.RequestBodyJSON)
|
||||
require.NotContains(t, *entry.RequestBodyJSON, "secret-token")
|
||||
require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]")
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
raw := []byte("not-json")
|
||||
setOpsRequestContext(c, "claude-3", false, raw)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.NotNil(t, entry.RequestBodyBytes)
|
||||
require.Equal(t, len(raw), *entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
require.Equal(t, int64(1), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
@ -108,39 +65,6 @@ func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) {
|
||||
require.Equal(t, int64(1), OpsErrorLogQueueLength())
|
||||
}
|
||||
|
||||
func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
entry := &service.OpsInsertErrorLogInput{}
|
||||
attachOpsRequestBodyToEntry(nil, entry)
|
||||
attachOpsRequestBodyToEntry(&gin.Context{}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
// 无请求体 key
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
require.False(t, entry.RequestBodyTruncated)
|
||||
|
||||
// 错误类型
|
||||
c.Set(opsRequestBodyKey, "not-bytes")
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
// 空 bytes
|
||||
c.Set(opsRequestBodyKey, []byte{})
|
||||
attachOpsRequestBodyToEntry(c, entry)
|
||||
require.Nil(t, entry.RequestBodyJSON)
|
||||
require.Nil(t, entry.RequestBodyBytes)
|
||||
|
||||
require.Equal(t, int64(0), OpsErrorLogSanitizedTotal())
|
||||
}
|
||||
|
||||
func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) {
|
||||
resetOpsErrorLoggerStateForTest(t)
|
||||
|
||||
@ -275,6 +199,218 @@ func TestNormalizeOpsErrorType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) {
|
||||
const message = "No available accounts"
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"Service temporarily unavailable",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
message string
|
||||
code string
|
||||
status int
|
||||
}{
|
||||
{
|
||||
name: "standard invalid API key",
|
||||
errType: "api_error",
|
||||
message: "Invalid API key",
|
||||
code: "INVALID_API_KEY",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "standard missing API key",
|
||||
errType: "api_error",
|
||||
message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header",
|
||||
code: "API_KEY_REQUIRED",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google invalid API key",
|
||||
errType: "api_error",
|
||||
message: "Invalid API key",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "google missing API key",
|
||||
errType: "api_error",
|
||||
message: "API key is required",
|
||||
code: "401",
|
||||
status: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
errType := normalizeOpsErrorType(tt.errType, tt.code)
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "auth", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "auth", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "client", errorOwner)
|
||||
require.Equal(t, "client_request", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "internal", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) {
|
||||
tests := []string{
|
||||
"No available accounts: no available accounts supporting model: made-up-model",
|
||||
"No available accounts: no available OpenAI accounts supporting model: made-up-model",
|
||||
"No available Gemini accounts: no available Gemini accounts supporting model: made-up-model",
|
||||
"No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)",
|
||||
}
|
||||
|
||||
for _, message := range tests {
|
||||
t.Run(message, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
markOpsRoutingCapacityLimited(c)
|
||||
|
||||
errType := normalizeOpsErrorType("api_error", "")
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable)
|
||||
|
||||
require.Equal(t, "api_error", errType)
|
||||
require.Equal(t, "routing", phase)
|
||||
require.True(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"No available accounts",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "routing", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "platform", errorOwner)
|
||||
require.Equal(t, "gateway", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "")
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"Invalid API key",
|
||||
"401",
|
||||
http.StatusUnauthorized,
|
||||
)
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
}
|
||||
|
||||
func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "")
|
||||
|
||||
phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(
|
||||
c,
|
||||
"api_error",
|
||||
"No available accounts",
|
||||
"",
|
||||
http.StatusServiceUnavailable,
|
||||
)
|
||||
|
||||
require.Equal(t, "upstream", phase)
|
||||
require.False(t, isBusinessLimited)
|
||||
require.Equal(t, "provider", errorOwner)
|
||||
require.Equal(t, "upstream_http", errorSource)
|
||||
}
|
||||
|
||||
func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected direct image path to be accepted")
|
||||
}
|
||||
want := filepath.Join(base, "logo.png")
|
||||
want := mustEvalSymlinks(t, filepath.Join(base, "logo.png"))
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) {
|
||||
if !ok {
|
||||
t.Fatal("expected nested image path to be accepted")
|
||||
}
|
||||
want = filepath.Join(base, "images", "logo.png")
|
||||
want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png"))
|
||||
if got != want {
|
||||
t.Fatalf("path = %q, want %q", got, want)
|
||||
}
|
||||
@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) {
|
||||
t.Fatalf("expected symlink escape to be rejected, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func mustEvalSymlinks(t *testing.T, path string) string {
|
||||
t.Helper()
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
t.Fatalf("eval symlinks for %q: %v", path, err)
|
||||
}
|
||||
return realPath
|
||||
}
|
||||
|
||||
@ -141,6 +141,7 @@ func (h *PaymentHandler) GetCheckoutInfo(c *gin.Context) {
|
||||
HelpText: cfg.HelpText,
|
||||
HelpImageURL: cfg.HelpImageURL,
|
||||
StripePublishableKey: cfg.StripePublishableKey,
|
||||
AlipayForceQRCode: cfg.AlipayForceQRCode,
|
||||
})
|
||||
}
|
||||
|
||||
@ -155,6 +156,7 @@ type checkoutInfoResponse struct {
|
||||
HelpText string `json:"help_text"`
|
||||
HelpImageURL string `json:"help_image_url"`
|
||||
StripePublishableKey string `json:"stripe_publishable_key"`
|
||||
AlipayForceQRCode bool `json:"alipay_force_qrcode"`
|
||||
}
|
||||
|
||||
type checkoutPlan struct {
|
||||
|
||||
@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
TablePageSizeOptions: settings.TablePageSizeOptions,
|
||||
CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems),
|
||||
CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints),
|
||||
DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
WeChatOAuthEnabled: settings.WeChatOAuthEnabled,
|
||||
WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled,
|
||||
|
||||
@ -67,6 +67,7 @@ type userProfileResponse struct {
|
||||
LinuxDoBound bool `json:"linuxdo_bound"`
|
||||
OIDCBound bool `json:"oidc_bound"`
|
||||
WeChatBound bool `json:"wechat_bound"`
|
||||
DingTalkBound bool `json:"dingtalk_bound"`
|
||||
}
|
||||
|
||||
type userProfileSourceContext struct {
|
||||
@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI
|
||||
LinuxDoBound: identities.LinuxDo.Bound,
|
||||
OIDCBound: identities.OIDC.Bound,
|
||||
WeChatBound: identities.WeChat.Bound,
|
||||
DingTalkBound: identities.DingTalk.Bound,
|
||||
}
|
||||
}
|
||||
|
||||
func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary {
|
||||
return map[string]service.UserIdentitySummary{
|
||||
"email": identities.Email,
|
||||
"linuxdo": identities.LinuxDo,
|
||||
"oidc": identities.OIDC,
|
||||
"wechat": identities.WeChat,
|
||||
"email": identities.Email,
|
||||
"linuxdo": identities.LinuxDo,
|
||||
"oidc": identities.OIDC,
|
||||
"wechat": identities.WeChat,
|
||||
"dingtalk": identities.DingTalk,
|
||||
}
|
||||
}
|
||||
|
||||
@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity
|
||||
|
||||
func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary {
|
||||
out := make([]service.UserIdentitySummary, 0, 3)
|
||||
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} {
|
||||
for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} {
|
||||
if summary.Bound {
|
||||
out = append(out, summary)
|
||||
}
|
||||
|
||||
@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string {
|
||||
|
||||
// CreatePayment creates an Alipay payment using the following routing:
|
||||
// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay.
|
||||
// - Desktop: prefer alipay.trade.precreate to get a scan payload directly.
|
||||
// - Desktop fallback: if precreate is unavailable for the merchant, fall back
|
||||
// to alipay.trade.page.pay and expose both pay_url and qr_code so the
|
||||
// frontend can render a QR while still allowing direct page open.
|
||||
// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to
|
||||
// get a scannable QR payload. If precreate is unavailable for the merchant,
|
||||
// fall back to alipay.trade.page.pay and expose pay_url only — the frontend
|
||||
// opens the Alipay checkout in a new tab.
|
||||
// - Desktop, paymentMode == "redirect": skip precreate and go straight to
|
||||
// alipay.trade.page.pay so the frontend always opens the Alipay checkout
|
||||
// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT.
|
||||
//
|
||||
// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable
|
||||
// payment QR. Never expose it via the QRCode field.
|
||||
func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) {
|
||||
client, err := a.getClient()
|
||||
if err != nil {
|
||||
@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment
|
||||
}
|
||||
|
||||
func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) {
|
||||
// Explicit redirect mode: merchant opted into "always open the Alipay
|
||||
// checkout page in a new tab" via the provider instance's payment_mode.
|
||||
// Skip precreate to avoid a wasted API call.
|
||||
if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") {
|
||||
return a.createPagePayTrade(client, req, notifyURL, returnURL)
|
||||
}
|
||||
|
||||
resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL)
|
||||
if precreateErr == nil {
|
||||
return resp, nil
|
||||
@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("alipay TradePagePay: %w", err)
|
||||
}
|
||||
// Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL
|
||||
// that must be opened in a browser, not a scannable payment QR. Setting it
|
||||
// as QRCode would let the frontend render an unscannable image.
|
||||
return &payment.CreatePaymentResponse{
|
||||
TradeNo: req.OrderID,
|
||||
PayURL: payURL.String(),
|
||||
QRCode: payURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) {
|
||||
if resp.PayURL == "" {
|
||||
t.Fatal("expected pay_url for desktop page pay")
|
||||
}
|
||||
if resp.QRCode != resp.PayURL {
|
||||
t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL)
|
||||
// page.pay returns a checkout page URL, not a scannable QR payload —
|
||||
// it must never be exposed via QRCode (the frontend would render an
|
||||
// unscannable image from it).
|
||||
if resp.QRCode != "" {
|
||||
t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode)
|
||||
}
|
||||
}
|
||||
|
||||
// When the provider instance is configured with paymentMode == "redirect",
|
||||
// the desktop flow must skip precreate and go straight to page.pay.
|
||||
func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) {
|
||||
origPreCreate := alipayTradePreCreate
|
||||
origPagePay := alipayTradePagePay
|
||||
t.Cleanup(func() {
|
||||
alipayTradePreCreate = origPreCreate
|
||||
alipayTradePagePay = origPagePay
|
||||
})
|
||||
|
||||
preCreateCalls := 0
|
||||
pagePayCalls := 0
|
||||
alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) {
|
||||
preCreateCalls++
|
||||
return &alipay.TradePreCreateRsp{
|
||||
Error: alipay.Error{Code: alipay.CodeSuccess},
|
||||
QRCode: "https://qr.alipay.example.com/precreate-token",
|
||||
}, nil
|
||||
}
|
||||
alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) {
|
||||
pagePayCalls++
|
||||
if param.ProductCode != alipayProductCodePagePay {
|
||||
t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay)
|
||||
}
|
||||
return url.Parse("https://openapi.alipay.com/gateway.do?page-pay")
|
||||
}
|
||||
|
||||
provider := &Alipay{
|
||||
config: map[string]string{"paymentMode": "redirect"},
|
||||
}
|
||||
resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{
|
||||
OrderID: "sub2_103",
|
||||
Amount: "12.00",
|
||||
Subject: "Balance recharge",
|
||||
}, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if preCreateCalls != 0 {
|
||||
t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls)
|
||||
}
|
||||
if pagePayCalls != 1 {
|
||||
t.Fatalf("page pay calls = %d, want 1", pagePayCalls)
|
||||
}
|
||||
if resp.PayURL == "" {
|
||||
t.Fatal("expected pay_url for redirect mode")
|
||||
}
|
||||
if resp.QRCode != "" {
|
||||
t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -254,6 +254,8 @@ const (
|
||||
proxyTLSHandshakeTimeout = 5 * time.Second
|
||||
// clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body)
|
||||
clientTimeout = 10 * time.Second
|
||||
// fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use.
|
||||
fetchAvailableModelsBodyLimit int64 = 8 << 20
|
||||
)
|
||||
|
||||
func NewClient(proxyURL string) (*Client, error) {
|
||||
@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct {
|
||||
// FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON
|
||||
// 支持 URL fallback:sandbox → daily → prod
|
||||
func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) {
|
||||
if c == nil || c.httpClient == nil {
|
||||
return nil, nil, errors.New("antigravity client is not configured")
|
||||
}
|
||||
|
||||
reqBody := FetchAvailableModelsRequest{Project: projectID}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
// 固定顺序:prod -> daily
|
||||
availableURLs := BaseURLs
|
||||
|
||||
fetchClient := c.fetchAvailableModelsHTTPClient()
|
||||
var lastErr error
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
apiURL := baseURL + "/v1internal:fetchAvailableModels"
|
||||
@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", GetUserAgentForContext(ctx))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
resp, err := fetchClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err)
|
||||
if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
|
||||
@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
respBodyBytes, err := io.ReadAll(resp.Body)
|
||||
respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1))
|
||||
_ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit {
|
||||
return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit)
|
||||
}
|
||||
|
||||
// 检查是否需要 URL 降级
|
||||
if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
|
||||
@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI
|
||||
return nil, nil, lastErr
|
||||
}
|
||||
|
||||
func (c *Client) fetchAvailableModelsHTTPClient() *http.Client {
|
||||
fetchClient := *c.httpClient
|
||||
fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect
|
||||
return &fetchClient
|
||||
}
|
||||
|
||||
func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("redirect url is nil")
|
||||
}
|
||||
if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) {
|
||||
return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllowedFetchAvailableModelsRedirectHost(host string) bool {
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
for _, baseURL := range BaseURLs {
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(host, parsed.Hostname()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ── Privacy API ──────────────────────────────────────────────────────
|
||||
|
||||
// privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致)
|
||||
|
||||
@ -744,6 +744,10 @@ func TestStreamingReasoning(t *testing.T) {
|
||||
assert.Equal(t, "content_block_start", events[0].Type)
|
||||
assert.Equal(t, "thinking", events[0].ContentBlock.Type)
|
||||
|
||||
sse, err := ResponsesAnthropicEventToSSE(events[0])
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, sse, `"content_block":{"thinking":"","type":"thinking"}`)
|
||||
|
||||
// reasoning text delta
|
||||
events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{
|
||||
Type: "response.reasoning_summary_text.delta",
|
||||
@ -1520,3 +1524,49 @@ func TestAnthropicToResponses_ToolWithNilSchema(t *testing.T) {
|
||||
assert.JSONEq(t, `"object"`, string(params["type"]))
|
||||
assert.JSONEq(t, `{}`, string(params["properties"]))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isReasoningModel / temperature-stripping tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAnthropicToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
|
||||
temp := 0.7
|
||||
req := &AnthropicRequest{
|
||||
Model: "gpt-5.2",
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Temperature: &temp,
|
||||
TopP: &temp,
|
||||
}
|
||||
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
|
||||
assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
|
||||
|
||||
// Verify the fields are absent from the serialised JSON.
|
||||
b, err := json.Marshal(resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, string(b), `"temperature"`)
|
||||
assert.NotContains(t, string(b), `"top_p"`)
|
||||
}
|
||||
|
||||
func TestAnthropicToResponses_TemperatureStrippedForAllGpt5Variants(t *testing.T) {
|
||||
temp := 1.0
|
||||
models := []string{"gpt-5.2", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex", "gpt-5.5"}
|
||||
for _, model := range models {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
req := &AnthropicRequest{
|
||||
Model: model,
|
||||
MaxTokens: 1024,
|
||||
Messages: []AnthropicMessage{{Role: "user", Content: json.RawMessage(`"Hello"`)}},
|
||||
Temperature: &temp,
|
||||
TopP: &temp,
|
||||
}
|
||||
resp, err := AnthropicToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Temperature, "model %s: temperature must be stripped", model)
|
||||
assert.Nil(t, resp.TopP, "model %s: top_p must be stripped", model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,12 +22,19 @@ func AnthropicToResponses(req *AnthropicRequest) (*ResponsesRequest, error) {
|
||||
}
|
||||
|
||||
out := &ResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
Model: req.Model,
|
||||
Input: inputJSON,
|
||||
Stream: req.Stream,
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
}
|
||||
|
||||
// Reasoning models (gpt-5.x) served via the Responses API do not accept
|
||||
// sampling parameters. Sending temperature or top_p causes a 400
|
||||
// "Unsupported parameter" error, so we only forward them for non-reasoning
|
||||
// models.
|
||||
if !isReasoningModel(req.Model) {
|
||||
out.Temperature = req.Temperature
|
||||
out.TopP = req.TopP
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
@ -437,6 +444,14 @@ func boolPtr(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
// isReasoningModel reports whether model is a reasoning model that does not
|
||||
// support sampling parameters (temperature, top_p) via the Responses API.
|
||||
// All gpt-5.x models are reasoning-only; the Responses API returns
|
||||
// "Unsupported parameter: temperature" if these fields are present.
|
||||
func isReasoningModel(model string) bool {
|
||||
return strings.HasPrefix(model, "gpt-5")
|
||||
}
|
||||
|
||||
// normalizeToolParameters ensures the tool parameter schema is valid for
|
||||
// OpenAI's Responses API, which requires "properties" on object schemas.
|
||||
//
|
||||
|
||||
@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi
|
||||
assert.Equal(t, "Describe this", parts[0].Text)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) {
|
||||
// Regression for #2515: the upstream Responses API rejects an input item
|
||||
// whose content field is JSON null. Any chat-completions message that
|
||||
// yields no usable content parts must serialize content as a string.
|
||||
cases := []struct {
|
||||
name string
|
||||
content json.RawMessage
|
||||
}{
|
||||
{"null content", json.RawMessage(`null`)},
|
||||
{"empty array content", json.RawMessage(`[]`)},
|
||||
{"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)},
|
||||
{"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-5.5",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: tc.content},
|
||||
},
|
||||
}
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, string(resp.Input), `"content":null`,
|
||||
"converted input must not contain a null content field")
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 1)
|
||||
assert.Equal(t, `""`, string(items[0].Content),
|
||||
"content must be an empty string, not null")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
@ -296,6 +331,48 @@ func TestChatCompletionsToResponses_ServiceTier(t *testing.T) {
|
||||
assert.Equal(t, "flex", resp.ServiceTier)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// temperature / top_p stripping for reasoning models
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestChatCompletionsToResponses_TemperatureStrippedForReasoningModel(t *testing.T) {
|
||||
temp := 0.7
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-5.2",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
Temperature: &temp,
|
||||
TopP: &temp,
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, resp.Temperature, "reasoning model: temperature must be stripped")
|
||||
assert.Nil(t, resp.TopP, "reasoning model: top_p must be stripped")
|
||||
|
||||
// Must not appear in the serialised request body sent to the upstream.
|
||||
b, err := json.Marshal(resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, string(b), `"temperature"`)
|
||||
assert.NotContains(t, string(b), `"top_p"`)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_TemperaturePreservedForNonReasoningModel(t *testing.T) {
|
||||
temp := 0.7
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{{Role: "user", Content: json.RawMessage(`"Hi"`)}},
|
||||
Temperature: &temp,
|
||||
TopP: &temp,
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp.Temperature, "non-reasoning model: temperature must be preserved")
|
||||
assert.InDelta(t, 0.7, *resp.Temperature, 1e-9)
|
||||
require.NotNil(t, resp.TopP, "non-reasoning model: top_p must be preserved")
|
||||
assert.InDelta(t, 0.7, *resp.TopP, 1e-9)
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantWithTextAndToolCalls(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
@ -379,6 +456,34 @@ func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T)
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
func TestChatCompletionsToResponses_AssistantReasoningContentPreserved(t *testing.T) {
|
||||
req := &ChatCompletionsRequest{
|
||||
Model: "gpt-4o",
|
||||
Messages: []ChatMessage{
|
||||
{Role: "user", Content: json.RawMessage(`"Hi"`)},
|
||||
{
|
||||
Role: "assistant",
|
||||
ReasoningContent: "internal plan",
|
||||
Content: json.RawMessage(`"final answer"`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := ChatCompletionsToResponses(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var items []ResponsesInputItem
|
||||
require.NoError(t, json.Unmarshal(resp.Input, &items))
|
||||
require.Len(t, items, 2)
|
||||
|
||||
var parts []ResponsesContentPart
|
||||
require.NoError(t, json.Unmarshal(items[1].Content, &parts))
|
||||
require.Len(t, parts, 1)
|
||||
assert.Equal(t, "output_text", parts[0].Type)
|
||||
assert.Contains(t, parts[0].Text, "<thinking>internal plan</thinking>")
|
||||
assert.Contains(t, parts[0].Text, "final answer")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ResponsesToChatCompletions tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@ -30,13 +30,18 @@ func ChatCompletionsToResponses(req *ChatCompletionsRequest) (*ResponsesRequest,
|
||||
Model: req.Model,
|
||||
Instructions: req.Instructions,
|
||||
Input: inputJSON,
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
Stream: true, // upstream always streams
|
||||
Include: []string{"reasoning.encrypted_content"},
|
||||
ServiceTier: req.ServiceTier,
|
||||
}
|
||||
|
||||
// Reasoning models (gpt-5.x) do not accept sampling parameters.
|
||||
// See isReasoningModel in anthropic_to_responses.go.
|
||||
if !isReasoningModel(req.Model) {
|
||||
out.Temperature = req.Temperature
|
||||
out.TopP = req.TopP
|
||||
}
|
||||
|
||||
storeFalse := false
|
||||
out.Store = &storeFalse
|
||||
|
||||
@ -150,6 +155,11 @@ func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
// empty/nil and there are tool_calls, only function_call items are emitted.
|
||||
func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
var items []ResponsesInputItem
|
||||
content := ""
|
||||
|
||||
if m.ReasoningContent != "" {
|
||||
content = "<thinking>" + m.ReasoningContent + "</thinking>"
|
||||
}
|
||||
|
||||
// Emit assistant message with output_text if content is non-empty.
|
||||
if len(m.Content) > 0 {
|
||||
@ -158,15 +168,22 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) {
|
||||
return nil, err
|
||||
}
|
||||
if s != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: s}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
content += s
|
||||
}
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
parts := []ResponsesContentPart{{Type: "output_text", Text: content}}
|
||||
partsJSON, err := json.Marshal(parts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON})
|
||||
}
|
||||
|
||||
// Emit one function_call item per tool_call.
|
||||
for _, tc := range m.ToolCalls {
|
||||
args := tc.Function.Arguments
|
||||
@ -325,7 +342,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error
|
||||
if content.Text != nil {
|
||||
return json.Marshal(*content.Text)
|
||||
}
|
||||
return json.Marshal(convertChatContentPartsToResponses(content.Parts))
|
||||
parts := convertChatContentPartsToResponses(content.Parts)
|
||||
if len(parts) == 0 {
|
||||
// A nil slice marshals to JSON null, which the upstream Responses API
|
||||
// rejects ("expected an array of objects or string, but got null").
|
||||
// Fall back to an empty string when no usable parts remain.
|
||||
return json.Marshal("")
|
||||
}
|
||||
return json.Marshal(parts)
|
||||
}
|
||||
|
||||
func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart {
|
||||
|
||||
@ -75,6 +75,28 @@ type AnthropicContentBlock struct {
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
}
|
||||
|
||||
func (b AnthropicContentBlock) MarshalJSON() ([]byte, error) {
|
||||
type anthropicContentBlock AnthropicContentBlock
|
||||
base := struct {
|
||||
anthropicContentBlock
|
||||
}{anthropicContentBlock: anthropicContentBlock(b)}
|
||||
|
||||
switch b.Type {
|
||||
case "text":
|
||||
return json.Marshal(struct {
|
||||
Text string `json:"text"`
|
||||
anthropicContentBlock
|
||||
}{Text: b.Text, anthropicContentBlock: anthropicContentBlock(b)})
|
||||
case "thinking":
|
||||
return json.Marshal(struct {
|
||||
Thinking string `json:"thinking"`
|
||||
anthropicContentBlock
|
||||
}{Thinking: b.Thinking, anthropicContentBlock: anthropicContentBlock(b)})
|
||||
default:
|
||||
return json.Marshal(base)
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicImageSource describes the source data for an image content block.
|
||||
type AnthropicImageSource struct {
|
||||
Type string `json:"type"` // "base64"
|
||||
@ -306,6 +328,37 @@ type ResponsesUsage struct {
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
func (u *ResponsesUsage) UnmarshalJSON(data []byte) error {
|
||||
type responsesUsageAlias ResponsesUsage
|
||||
var aux struct {
|
||||
responsesUsageAlias
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
*u = ResponsesUsage(aux.responsesUsageAlias)
|
||||
if u.InputTokens == 0 && aux.PromptTokens != 0 {
|
||||
u.InputTokens = aux.PromptTokens
|
||||
}
|
||||
if u.OutputTokens == 0 && aux.CompletionTokens != 0 {
|
||||
u.OutputTokens = aux.CompletionTokens
|
||||
}
|
||||
if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil {
|
||||
u.InputTokensDetails = aux.PromptTokensDetails
|
||||
}
|
||||
if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil {
|
||||
u.OutputTokensDetails = aux.CompletionTokensDetails
|
||||
}
|
||||
if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) {
|
||||
u.TotalTokens = u.InputTokens + u.OutputTokens
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails breaks down input token usage.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens int `json:"cached_tokens,omitempty"`
|
||||
|
||||
@ -22,6 +22,7 @@ func DefaultModels() []Model {
|
||||
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-flash-image", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.5-flash", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
|
||||
{Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods},
|
||||
|
||||
@ -15,6 +15,7 @@ var DefaultModels = []Model{
|
||||
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-flash-image", Type: "model", DisplayName: "Gemini 2.5 Flash Image", CreatedAt: ""},
|
||||
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
|
||||
{ID: "gemini-3.5-flash", Type: "model", DisplayName: "Gemini 3.5 Flash", CreatedAt: ""},
|
||||
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
|
||||
{ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""},
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
// pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems)
|
||||
package openai_compat
|
||||
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。
|
||||
// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。
|
||||
//
|
||||
// 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。
|
||||
type AccountResponsesSupport int
|
||||
@ -35,11 +35,43 @@ const (
|
||||
ResponsesSupportNo
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。
|
||||
// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。
|
||||
type ResponsesSupportMode string
|
||||
|
||||
const (
|
||||
// ResponsesSupportModeAuto 表示跟随自动探测结果。
|
||||
ResponsesSupportModeAuto ResponsesSupportMode = "auto"
|
||||
|
||||
// ResponsesSupportModeForceResponses 强制使用 /v1/responses。
|
||||
ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses"
|
||||
|
||||
// ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。
|
||||
ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions"
|
||||
)
|
||||
|
||||
// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。
|
||||
// 值类型为 string:auto=跟随探测,force_responses=强制 Responses,
|
||||
// force_chat_completions=强制 Chat Completions。
|
||||
const ExtraKeyResponsesMode = "openai_responses_mode"
|
||||
|
||||
// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。
|
||||
// 值类型为 bool:true=支持、false=不支持、键缺失=未探测。
|
||||
const ExtraKeyResponsesSupported = "openai_responses_supported"
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。
|
||||
// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。
|
||||
// 缺失或非法值按 auto 处理,以保持存量行为。
|
||||
func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode {
|
||||
switch ResponsesSupportMode(mode) {
|
||||
case ResponsesSupportModeForceResponses:
|
||||
return ResponsesSupportModeForceResponses
|
||||
case ResponsesSupportModeForceChatCompletions:
|
||||
return ResponsesSupportModeForceChatCompletions
|
||||
default:
|
||||
return ResponsesSupportModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。
|
||||
//
|
||||
// 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按
|
||||
// "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。
|
||||
@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport {
|
||||
if extra == nil {
|
||||
return ResponsesSupportUnknown
|
||||
}
|
||||
if mode, ok := extra[ExtraKeyResponsesMode].(string); ok {
|
||||
switch NormalizeResponsesSupportMode(mode) {
|
||||
case ResponsesSupportModeForceResponses:
|
||||
return ResponsesSupportYes
|
||||
case ResponsesSupportModeForceChatCompletions:
|
||||
return ResponsesSupportNo
|
||||
}
|
||||
}
|
||||
v, ok := extra[ExtraKeyResponsesSupported]
|
||||
if !ok {
|
||||
return ResponsesSupportUnknown
|
||||
|
||||
@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) {
|
||||
{"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown},
|
||||
{"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown},
|
||||
{"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown},
|
||||
{"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes},
|
||||
{"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo},
|
||||
{"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo},
|
||||
{"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes},
|
||||
{"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes},
|
||||
{"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
// 已探测:标记决定
|
||||
{"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true},
|
||||
{"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false},
|
||||
|
||||
// 手动覆盖:覆盖自动探测结果
|
||||
{"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true},
|
||||
{"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesSupportMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
want ResponsesSupportMode
|
||||
}{
|
||||
{"empty", "", ResponsesSupportModeAuto},
|
||||
{"auto", "auto", ResponsesSupportModeAuto},
|
||||
{"force responses", "force_responses", ResponsesSupportModeForceResponses},
|
||||
{"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions},
|
||||
{"invalid", "enabled", ResponsesSupportModeAuto},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := NormalizeResponsesSupportMode(tc.mode)
|
||||
if got != tc.want {
|
||||
t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,6 +230,20 @@ type UserDashboardStats struct {
|
||||
// 性能指标
|
||||
Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数
|
||||
Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数
|
||||
|
||||
// 按"有效平台"维度拆分(与 ops 路径口径一致:group.platform 优先,否则 account.platform)
|
||||
ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"`
|
||||
}
|
||||
|
||||
// PlatformDashboardStats 单个平台的用量明细。
|
||||
type PlatformDashboardStats struct {
|
||||
Platform string `json:"platform"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
TodayRequests int64 `json:"today_requests"`
|
||||
TodayTokens int64 `json:"today_tokens"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
}
|
||||
|
||||
// UsageLogFilters represents filters for usage log queries
|
||||
@ -265,13 +279,22 @@ type UsageStats struct {
|
||||
EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。
|
||||
// Platform 取值与 ops 路径口径一致:优先 groups.platform,否则 accounts.platform。
|
||||
type PlatformUsage struct {
|
||||
Platform string `json:"platform"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
}
|
||||
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
TodayActualCost float64 `json:"today_actual_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
ByPlatform []PlatformUsage `json:"by_platform,omitempty"`
|
||||
}
|
||||
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user