chore: merge upstream v0.1.110, keep Claude customizations

Merge strategy: keep local Claude customizations (ours), accept all other upstream changes.

Claude constants.go retains:
- DefaultCLIVersion = 2.1.88
- Enhanced beta headers (9 new betas)
- ModelSupports1M() function
- GetOAuthBetaHeader() function
- GetAPIKeyBetaHeader() function
- ApplyFingerprintOverrides() function

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
This commit is contained in:
win 2026-04-09 00:46:41 +08:00
commit 5595297203
41 changed files with 1742 additions and 409 deletions

View File

@ -19,7 +19,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
go version | grep -q 'go1.26.2'
- name: Unit tests
working-directory: backend
run: make test-unit
@ -38,7 +38,7 @@ jobs:
cache: true
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
go version | grep -q 'go1.26.2'
- name: golangci-lint
uses: golangci/golangci-lint-action@v9
with:

View File

@ -115,7 +115,7 @@ jobs:
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
go version | grep -q 'go1.26.2'
# Docker setup for GoReleaser
- name: Set up QEMU

View File

@ -23,7 +23,7 @@ jobs:
cache-dependency-path: backend/go.sum
- name: Verify Go version
run: |
go version | grep -q 'go1.26.1'
go version | grep -q 'go1.26.2'
- name: Run govulncheck
working-directory: backend
run: |

View File

@ -69,6 +69,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click <a href="https://ctok.ai">here</a> to register!</td>
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>Thanks to SilkAPI for sponsoring this project! <a href="https://code.silkapi.com/">SilkAPI</a> is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay.</td>
</tr>
</table>
## Ecosystem

View File

@ -69,6 +69,10 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的
<td>感谢 CTok.ai 赞助了本项目CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群为开发者提供稳定的服务保障和持续的技术支持让 AI 辅助编程真正成为开发者的生产力工具。点击<a href="https://ctok.ai">这里</a>注册!</td>
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>感谢 丝绸API 赞助了本项目! <a href="https://code.silkapi.com/">丝绸API</a> 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。</td>
</tr>
</table>
## 生态项目

View File

@ -68,6 +68,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを
<td width="180"><a href="https://ctok.ai"><img src="assets/partners/logos/ctok.png" alt="CTok" width="150"></a></td>
<td>CTok.ai のご支援に感謝しますCTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。<a href="https://ctok.ai">こちら</a>から登録!</td>
</tr>
<tr>
<td width="180"><a href="https://code.silkapi.com/"><img src="assets/partners/logos/silkapi.png" alt="silkapi" width="150"></a></td>
<td>SilkAPI のご支援に感謝します!<a href="https://code.silkapi.com/">SilkAPI</a> は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。</td>
</tr>
</table>
## エコシステム

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

View File

@ -1 +1 @@
0.1.109
0.1.110

View File

@ -1,12 +1,12 @@
module github.com/Wei-Shaw/sub2api
go 1.26.1
go 1.26.2
require (
entgo.io/ent v0.14.5
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/DouDOU-start/go-sora2api v1.1.0
github.com/alitto/pond/v2 v2.6.2
github.com/andybalholm/brotli v1.2.0
github.com/aws/aws-sdk-go-v2 v1.41.3
github.com/aws/aws-sdk-go-v2/config v1.32.10
github.com/aws/aws-sdk-go-v2/credentials v1.19.10
@ -50,7 +50,6 @@ require (
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect
@ -67,14 +66,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/bdandy/go-errors v1.2.2 // indirect
github.com/bdandy/go-socks4 v1.2.3 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/bogdanfinn/fhttp v0.6.8 // indirect
github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect
github.com/bogdanfinn/tls-client v1.14.0 // indirect
github.com/bogdanfinn/utls v1.7.7-barnius // indirect
github.com/bogdanfinn/websocket v1.5.5-barnius // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
@ -151,7 +143,6 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect
github.com/testcontainers/testcontainers-go v0.40.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect

View File

@ -10,8 +10,6 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU=
github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo=
@ -60,24 +58,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8
github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q=
github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM=
github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic=
github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4=
github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s=
github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg=
github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A=
github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM=
github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU=
github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg=
github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI=
github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@ -94,10 +78,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
@ -199,8 +179,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4=
github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y=
github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI=
github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@ -236,8 +214,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
@ -271,8 +247,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@ -324,8 +298,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ=
@ -347,8 +319,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc=
github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng=
github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU=
github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY=
github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk=
@ -421,15 +391,12 @@ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -439,15 +406,12 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -1,8 +1,6 @@
package admin
import (
"errors"
"fmt"
"strconv"
"strings"
@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
return result
}
// validatePricingBillingMode 校验计费配置
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
for _, p := range pricing {
// 按次/图片模式必须配置默认价格或区间
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return errors.New("per-request price or intervals required for per_request/image billing mode")
}
}
// 校验价格不能为负
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
return err
}
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
return err
}
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
return err
}
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
return err
}
// 校验 interval至少有一个价格字段非空
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
}
}
}
return nil
}
func validatePriceNotNegative(field string, val *float64) error {
if val != nil && *val < 0 {
return fmt.Errorf("%s must be >= 0", field)
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- Handlers ---
// List handles listing channels with pagination
@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) {
}
pricing := pricingRequestToService(req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
Name: req.Name,
@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) {
}
if req.ModelPricing != nil {
pricing := pricingRequestToService(*req.ModelPricing)
if err := validatePricingBillingMode(pricing); err != nil {
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
return
}
input.ModelPricing = &pricing
}

View File

@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
require.Nil(t, r.ImageOutputPrice)
require.Nil(t, r.PerRequestPrice)
}
// ---------------------------------------------------------------------------
// 3. validatePricingBillingMode
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []service.ChannelModelPricing
wantErr bool
}{
{
name: "token mode - valid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeToken},
},
wantErr: false,
},
{
name: "per_request with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
},
wantErr: false,
},
{
name: "per_request with intervals - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModePerRequest,
Intervals: []service.PricingInterval{
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
},
},
},
wantErr: false,
},
{
name: "per_request no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModePerRequest},
},
wantErr: true,
},
{
name: "image with price - valid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeImage,
PerRequestPrice: float64Ptr(0.2),
},
},
wantErr: false,
},
{
name: "image no price no intervals - invalid",
pricing: []service.ChannelModelPricing{
{BillingMode: service.BillingModeImage},
},
wantErr: true,
},
{
name: "empty list - valid",
pricing: []service.ChannelModelPricing{},
wantErr: false,
},
{
name: "mixed modes with invalid image - invalid",
pricing: []service.ChannelModelPricing{
{
BillingMode: service.BillingModeToken,
InputPrice: float64Ptr(0.01),
},
{
BillingMode: service.BillingModePerRequest,
PerRequestPrice: float64Ptr(0.5),
},
{
BillingMode: service.BillingModeImage,
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), "per-request price or intervals required")
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -128,6 +128,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
BackendModeEnabled: settings.BackendModeEnabled,
EnableFingerprintUnification: settings.EnableFingerprintUnification,
EnableMetadataPassthrough: settings.EnableMetadataPassthrough,
EnableCCHSigning: settings.EnableCCHSigning,
})
}
@ -211,6 +212,7 @@ type UpdateSettingsRequest struct {
// Gateway forwarding behavior
EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"`
EnableCCHSigning *bool `json:"enable_cch_signing"`
}
// UpdateSettings 更新系统设置
@ -614,6 +616,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
return previousSettings.EnableMetadataPassthrough
}(),
EnableCCHSigning: func() bool {
if req.EnableCCHSigning != nil {
return *req.EnableCCHSigning
}
return previousSettings.EnableCCHSigning
}(),
}
if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil {
@ -693,6 +701,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
BackendModeEnabled: updatedSettings.BackendModeEnabled,
EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification,
EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough,
EnableCCHSigning: updatedSettings.EnableCCHSigning,
})
}
@ -871,6 +880,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.EnableMetadataPassthrough != after.EnableMetadataPassthrough {
changed = append(changed, "enable_metadata_passthrough")
}
if before.EnableCCHSigning != after.EnableCCHSigning {
changed = append(changed, "enable_cch_signing")
}
return changed
}

View File

@ -97,6 +97,7 @@ type SystemSettings struct {
// Gateway forwarding behavior
EnableFingerprintUnification bool `json:"enable_fingerprint_unification"`
EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"`
EnableCCHSigning bool `json:"enable_cch_signing"`
}
type DefaultSubscriptionSetting struct {

View File

@ -758,13 +758,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
})
}
if len(funcDecls) == 0 {
if !hasWebSearch {
return nil
}
// Web Search 工具映射
return []GeminiToolDeclaration{{
var declarations []GeminiToolDeclaration
if len(funcDecls) > 0 {
declarations = append(declarations, GeminiToolDeclaration{
FunctionDeclarations: funcDecls,
})
}
if hasWebSearch {
declarations = append(declarations, GeminiToolDeclaration{
GoogleSearch: &GeminiGoogleSearch{
EnhancedContent: &GeminiEnhancedContent{
ImageSearch: &GeminiImageSearch{
@ -772,10 +773,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration {
},
},
},
}}
})
}
if len(declarations) == 0 {
return nil
}
return []GeminiToolDeclaration{{
FunctionDeclarations: funcDecls,
}}
return declarations
}

View File

@ -300,6 +300,29 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
}
}
func TestBuildTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
tools := []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "web_search_20250305",
Name: "web_search",
},
}
result := buildTools(tools)
require.Len(t, result, 2)
require.Len(t, result[0].FunctionDeclarations, 1)
require.Equal(t, "get_weather", result[0].FunctionDeclarations[0].Name)
require.NotNil(t, result[1].GoogleSearch)
require.NotNil(t, result[1].GoogleSearch.EnhancedContent)
require.NotNil(t, result[1].GoogleSearch.EnhancedContent.ImageSearch)
require.Equal(t, 5, result[1].GoogleSearch.EnhancedContent.ImageSearch.MaxResultCount)
}
func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) {
tests := []struct {
name string
@ -437,3 +460,36 @@ func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t
})
}
}
func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions(t *testing.T) {
claudeReq := &ClaudeRequest{
Model: "claude-3-5-sonnet-latest",
Messages: []ClaudeMessage{
{
Role: "user",
Content: json.RawMessage(`[{"type":"text","text":"hello"}]`),
},
},
Tools: []ClaudeTool{
{
Name: "get_weather",
Description: "Get weather information",
InputSchema: map[string]any{"type": "object"},
},
{
Type: "web_search_20250305",
Name: "web_search",
},
},
}
body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions())
require.NoError(t, err)
var req V1InternalRequest
require.NoError(t, json.Unmarshal(body, &req))
require.Len(t, req.Request.Tools, 2)
require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1)
require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name)
require.NotNil(t, req.Request.Tools[1].GoogleSearch)
}

View File

@ -181,6 +181,50 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) {
assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL)
}
func TestChatCompletionsToResponses_EmptyBase64ImageURLSkipped(t *testing.T) {
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(content)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Describe this", parts[0].Text)
}
func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testing.T) {
content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64, "}}]`
req := &ChatCompletionsRequest{
Model: "gpt-4o",
Messages: []ChatMessage{
{Role: "user", Content: json.RawMessage(content)},
},
}
resp, err := ChatCompletionsToResponses(req)
require.NoError(t, err)
var items []ResponsesInputItem
require.NoError(t, json.Unmarshal(resp.Input, &items))
require.Len(t, items, 1)
var parts []ResponsesContentPart
require.NoError(t, json.Unmarshal(items[0].Content, &parts))
require.Len(t, parts, 1)
assert.Equal(t, "input_text", parts[0].Type)
assert.Equal(t, "Describe this", parts[0].Text)
}
func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) {
req := &ChatCompletionsRequest{
Model: "gpt-4o",

View File

@ -339,7 +339,7 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont
})
}
case "image_url":
if p.ImageURL != nil && p.ImageURL.URL != "" {
if p.ImageURL != nil && p.ImageURL.URL != "" && !isEmptyBase64DataURI(p.ImageURL.URL) {
responseParts = append(responseParts, ResponsesContentPart{
Type: "input_image",
ImageURL: p.ImageURL.URL,
@ -350,6 +350,22 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont
return responseParts
}
func isEmptyBase64DataURI(raw string) bool {
if !strings.HasPrefix(raw, "data:") {
return false
}
rest := strings.TrimPrefix(raw, "data:")
semicolonIdx := strings.Index(rest, ";")
if semicolonIdx < 0 {
return false
}
rest = rest[semicolonIdx+1:]
if !strings.HasPrefix(rest, "base64,") {
return false
}
return strings.TrimSpace(strings.TrimPrefix(rest, "base64,")) == ""
}
func flattenChatContentParts(parts []ChatContentPart) string {
var textParts []string
for _, p := range parts {

View File

@ -536,6 +536,7 @@ func TestAPIContracts(t *testing.T) {
"max_claude_code_version": "",
"allow_ungrouped_key_scheduling": false,
"backend_mode_enabled": false,
"enable_cch_signing": false,
"enable_fingerprint_unification": true,
"enable_metadata_passthrough": false,
"custom_menu_items": [],

View File

@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
}
}
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
func (s *ChannelService) storeErrorCache() {
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
}
// buildCache 从数据库构建渠道缓存。
// 使用独立 context 避免请求取消导致空值被长期缓存。
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
// 断开请求取消链,避免客户端断连导致空值被长期缓存
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
defer cancel()
channels, err := s.repo.ListAll(dbCtx)
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
if err != nil {
// error-TTL失败时存入短 TTL 空缓存,防止紧密重试
slog.Warn("failed to build channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
s.cache.Store(errorCache)
return nil, fmt.Errorf("list all channels: %w", err)
return nil, err
}
cache := populateChannelCache(channels, groupPlatforms)
s.cache.Store(cache)
return cache, nil
}
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
channels, err := s.repo.ListAll(ctx)
if err != nil {
slog.Warn("failed to build channel cache", "error", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("list all channels: %w", err)
}
// 收集所有 groupID批量查询 platform
var allGroupIDs []int64
for i := range channels {
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
}
groupPlatforms := make(map[int64]string)
if len(allGroupIDs) > 0 {
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
if err != nil {
slog.Warn("failed to load group platforms for channel cache", "error", err)
errorCache := newEmptyChannelCache()
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
s.cache.Store(errorCache)
return nil, fmt.Errorf("get group platforms: %w", err)
s.storeErrorCache()
return nil, nil, fmt.Errorf("get group platforms: %w", err)
}
}
return channels, groupPlatforms, nil
}
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
cache := newEmptyChannelCache()
cache.groupPlatform = groupPlatforms
cache.byID = make(map[int64]*Channel, len(channels))
@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
for i := range channels {
ch := &channels[i]
cache.byID[ch.ID] = ch
for _, gid := range ch.GroupIDs {
cache.channelByGroupID[gid] = ch
platform := groupPlatforms[gid]
@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
expandMappingToCache(cache, ch, gid, platform)
}
}
// 通配符条目保持配置顺序(最先匹配到优先)
s.cache.Store(cache)
return cache, nil
return cache
}
// invalidateCache 使缓存失效,让下次读取时自然重建
@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
// 返回 true 表示模型被限制(不在允许列表中)。
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
lk, _ := s.lookupGroupChannel(ctx, groupID)
lk, err := s.lookupGroupChannel(ctx, groupID)
if err != nil {
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
}
if lk == nil {
return false
}
@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
return newBody
}
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
// Create 和 Update 共用此函数,避免重复。
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
if err := validateNoConflictingModels(pricing); err != nil {
return err
}
if err := validatePricingIntervals(pricing); err != nil {
return err
}
if err := validateNoConflictingMappings(mapping); err != nil {
return err
}
return validatePricingBillingMode(pricing)
}
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
for _, p := range pricing {
if err := checkBillingModeRequirements(p); err != nil {
return err
}
if err := checkPricesNotNegative(p); err != nil {
return err
}
if err := checkIntervalsHavePrices(p); err != nil {
return err
}
}
return nil
}
func checkBillingModeRequirements(p ChannelModelPricing) error {
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
return infraerrors.BadRequest(
"BILLING_MODE_MISSING_PRICE",
"per-request price or intervals required for per_request/image billing mode",
)
}
}
return nil
}
func checkPricesNotNegative(p ChannelModelPricing) error {
checks := []struct {
field string
val *float64
}{
{"input_price", p.InputPrice},
{"output_price", p.OutputPrice},
{"cache_write_price", p.CacheWritePrice},
{"cache_read_price", p.CacheReadPrice},
{"image_output_price", p.ImageOutputPrice},
{"per_request_price", p.PerRequestPrice},
}
for _, c := range checks {
if c.val != nil && *c.val < 0 {
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
}
}
return nil
}
func checkIntervalsHavePrices(p ChannelModelPricing) error {
for _, iv := range p.Intervals {
if iv.InputPrice == nil && iv.OutputPrice == nil &&
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
iv.PerRequestPrice == nil {
return infraerrors.BadRequest(
"INTERVAL_MISSING_PRICE",
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
)
}
}
return nil
}
func formatMaxTokens(max *int) string {
if max == nil {
return "∞"
}
return fmt.Sprintf("%d", *max)
}
// --- CRUD ---
// Create 创建渠道
@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
return nil, ErrChannelExists
}
// 检查分组冲突
if len(input.GroupIDs) > 0 {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
return nil, err
}
channel := &Channel{
@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
channel.BillingModelSource = BillingModelSourceChannelMapped
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
return nil, fmt.Errorf("get channel: %w", err)
}
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
if err != nil {
return nil, fmt.Errorf("check channel exists: %w", err)
}
if exists {
return nil, ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
// 检查分组冲突
if input.GroupIDs != nil {
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
if err != nil {
return nil, fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return nil, ErrGroupAlreadyInChannel
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
return nil, err
}
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
return nil, err
}
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
return nil, err
}
// 先获取旧分组Update 后旧分组关联已删除,无法再查到
var oldGroupIDs []int64
if s.authCacheInvalidator != nil {
var err2 error
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
if err2 != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
}
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
return nil, err
}
oldGroupIDs := s.getOldGroupIDs(ctx, id)
if err := s.repo.Update(ctx, channel); err != nil {
return nil, fmt.Errorf("update channel: %w", err)
}
s.invalidateCache()
// 失效新旧分组的 auth 缓存
if s.authCacheInvalidator != nil {
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
for _, gid := range oldGroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
for _, gid := range channel.GroupIDs {
if _, ok := seen[gid]; !ok {
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
return s.repo.GetByID(ctx, id)
}
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
if input.Name != "" && input.Name != channel.Name {
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
if err != nil {
return fmt.Errorf("check channel exists: %w", err)
}
if exists {
return ErrChannelExists
}
channel.Name = input.Name
}
if input.Description != nil {
channel.Description = *input.Description
}
if input.Status != "" {
channel.Status = input.Status
}
if input.RestrictModels != nil {
channel.RestrictModels = *input.RestrictModels
}
if input.GroupIDs != nil {
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
return err
}
channel.GroupIDs = *input.GroupIDs
}
if input.ModelPricing != nil {
channel.ModelPricing = *input.ModelPricing
}
if input.ModelMapping != nil {
channel.ModelMapping = input.ModelMapping
}
if input.BillingModelSource != "" {
channel.BillingModelSource = input.BillingModelSource
}
return nil
}
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
// channelID 为当前渠道 IDCreate 时传 0
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
if err != nil {
return fmt.Errorf("check group conflicts: %w", err)
}
if len(conflicting) > 0 {
return ErrGroupAlreadyInChannel
}
return nil
}
// getOldGroupIDs 获取渠道更新前的关联分组 ID用于失效 auth 缓存)。
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
if s.authCacheInvalidator == nil {
return nil
}
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
if err != nil {
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
}
return oldGroupIDs
}
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
if s.authCacheInvalidator == nil {
return
}
seen := make(map[int64]struct{})
for _, ids := range groupIDSets {
for _, gid := range ids {
if _, ok := seen[gid]; ok {
continue
}
seen[gid] = struct{}{}
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
}
// Delete 删除渠道
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
// 先获取关联分组用于失效缓存
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
if err != nil {
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
}
s.invalidateCache()
if s.authCacheInvalidator != nil {
for _, gid := range groupIDs {
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
}
}
s.invalidateAuthCacheForGroups(ctx, groupIDs)
return nil
}

View File

@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
require.Equal(t, int64(601), result.ID)
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
}
// ---------------------------------------------------------------------------
// 10. ToUsageFields
// ---------------------------------------------------------------------------
func TestToUsageFields_NoMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-opus-4",
ChannelID: 1,
Mapped: false,
BillingModelSource: BillingModelSourceRequested,
}
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
require.Equal(t, int64(1), fields.ChannelID)
require.Equal(t, "claude-opus-4", fields.OriginalModel)
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
require.Empty(t, fields.ModelMappingChain)
}
func TestToUsageFields_WithChannelMapping(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4-20250514",
ChannelID: 2,
Mapped: true,
BillingModelSource: BillingModelSourceChannelMapped,
}
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
require.Equal(t, int64(2), fields.ChannelID)
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
r := ChannelMappingResult{
MappedModel: "claude-sonnet-4",
ChannelID: 3,
Mapped: true,
BillingModelSource: BillingModelSourceUpstream,
}
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
require.Equal(t, "my-alias", fields.OriginalModel)
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
}
// ---------------------------------------------------------------------------
// 11. validatePricingBillingMode (moved from handler tests)
// ---------------------------------------------------------------------------
func TestValidatePricingBillingMode(t *testing.T) {
tests := []struct {
name string
pricing []ChannelModelPricing
wantErr bool
errMsg string
}{
{
name: "token mode - valid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
},
{
name: "per_request with price - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
}},
},
{
name: "per_request with intervals - valid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
}},
},
{
name: "per_request no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "image no price no intervals - invalid",
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
wantErr: true,
errMsg: "per-request price or intervals required",
},
{
name: "empty list - valid",
pricing: []ChannelModelPricing{},
},
{
name: "negative input_price - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(-0.01),
}},
wantErr: true,
errMsg: "input_price must be >= 0",
},
{
name: "interval with no price fields - invalid",
pricing: []ChannelModelPricing{{
BillingMode: BillingModePerRequest,
PerRequestPrice: testPtrFloat64(0.5),
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
}},
wantErr: true,
errMsg: "has no price fields set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validatePricingBillingMode(tt.pricing)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
} else {
require.NoError(t, err)
}
})
}
}
// ---------------------------------------------------------------------------
// 12. Antigravity wildcard mapping isolation
// ---------------------------------------------------------------------------
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
ch := Channel{
ID: 1,
Status: StatusActive,
GroupIDs: []int64{10, 20},
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {"claude-*": "claude-override"},
PlatformGemini: {"gemini-*": "gemini-override"},
},
}
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
svc := newTestChannelService(repo)
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
require.False(t, result.Mapped)
require.Equal(t, "claude-opus-4", result.MappedModel)
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
require.False(t, result.Mapped)
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
// anthropic 分组应该能看到 anthropic 的通配符映射
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
require.True(t, result.Mapped)
require.Equal(t, "claude-override", result.MappedModel)
}
// ---------------------------------------------------------------------------
// 13. Create/Update with mapping conflict validation
// ---------------------------------------------------------------------------
func TestCreate_MappingConflict(t *testing.T) {
repo := &mockChannelRepository{}
svc := newTestChannelService(repo)
_, err := svc.Create(context.Background(), &CreateChannelInput{
Name: "test",
ModelMapping: map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}
func TestUpdate_MappingConflict(t *testing.T) {
existingChannel := &Channel{
ID: 1,
Name: "existing",
Status: StatusActive,
}
repo := &mockChannelRepository{
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
return existingChannel, nil
},
}
svc := newTestChannelService(repo)
conflictMapping := map[string]map[string]string{
PlatformAnthropic: {
"claude-*": "target-a",
"claude-opus-*": "target-b",
},
}
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
ModelMapping: conflictMapping,
})
require.Error(t, err)
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
}

View File

@ -218,6 +218,8 @@ const (
SettingKeyEnableFingerprintUnification = "enable_fingerprint_unification"
// SettingKeyEnableMetadataPassthrough 是否透传客户端原始 metadata.user_id默认 false
SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough"
// SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false
SettingKeyEnableCCHSigning = "enable_cch_signing"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@ -761,7 +761,9 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock(
system := gjson.GetBytes(upstream.lastBody, "system")
require.True(t, system.Exists())
require.Equal(t, claudeCodeSystemPrompt, system.String())
require.True(t, system.IsArray(), "system should be an array")
require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String())
require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String())
// 原始 system prompt 应迁移至 messages 中
messages := gjson.GetBytes(upstream.lastBody, "messages")

View File

@ -0,0 +1,73 @@
package service
import (
"fmt"
"regexp"
"strings"
"github.com/cespare/xxhash/v2"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ccVersionInBillingRe matches the semver part of cc_version (X.Y.Z), preserving
// the trailing message-derived suffix (e.g. ".c02") if present.
var ccVersionInBillingRe = regexp.MustCompile(`cc_version=\d+\.\d+\.\d+`)
// cchPlaceholderRe matches the cch=00000 placeholder in billing header text,
// scoped to x-anthropic-billing-header to avoid touching user content.
var cchPlaceholderRe = regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)(00000)(;)`)
const cchSeed uint64 = 0x6E52736AC806831E
// syncBillingHeaderVersion rewrites cc_version in x-anthropic-billing-header
// system text blocks to match the version extracted from userAgent.
// Only touches system array blocks whose text starts with "x-anthropic-billing-header".
func syncBillingHeaderVersion(body []byte, userAgent string) []byte {
version := ExtractCLIVersion(userAgent)
if version == "" {
return body
}
systemResult := gjson.GetBytes(body, "system")
if !systemResult.Exists() || !systemResult.IsArray() {
return body
}
replacement := "cc_version=" + version
idx := 0
systemResult.ForEach(func(_, item gjson.Result) bool {
text := item.Get("text")
if text.Exists() && text.Type == gjson.String &&
strings.HasPrefix(text.String(), "x-anthropic-billing-header") {
newText := ccVersionInBillingRe.ReplaceAllString(text.String(), replacement)
if newText != text.String() {
if updated, err := sjson.SetBytes(body, fmt.Sprintf("system.%d.text", idx), newText); err == nil {
body = updated
}
}
}
idx++
return true
})
return body
}
// signBillingHeaderCCH computes the xxHash64-based CCH signature for the request
// body and replaces the cch=00000 placeholder with the computed 5-hex-char hash.
// The body must contain the placeholder when this function is called.
func signBillingHeaderCCH(body []byte) []byte {
if !cchPlaceholderRe.Match(body) {
return body
}
cch := fmt.Sprintf("%05x", xxHash64Seeded(body, cchSeed)&0xFFFFF)
return cchPlaceholderRe.ReplaceAll(body, []byte("${1}"+cch+"${3}"))
}
// xxHash64Seeded computes xxHash64 of data with a custom seed.
func xxHash64Seeded(data []byte, seed uint64) uint64 {
d := xxhash.NewWithSeed(seed)
_, _ = d.Write(data)
return d.Sum64()
}

View File

@ -0,0 +1,165 @@
package service
import (
"fmt"
"testing"
"github.com/cespare/xxhash/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestSyncBillingHeaderVersion(t *testing.T) {
tests := []struct {
name string
body string
userAgent string
wantSub string // substring expected in result
unchanged bool // expect body to remain the same
}{
{
name: "replaces cc_version preserving message-derived suffix",
body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81.df2; cc_entrypoint=cli; cch=00000;"},{"type":"text","text":"You are Claude Code.","cache_control":{"type":"ephemeral"}}],"messages":[]}`,
userAgent: "claude-cli/2.1.22 (external, cli)",
wantSub: "cc_version=2.1.22.df2",
},
{
name: "no billing header in system",
body: `{"system":[{"type":"text","text":"You are Claude Code."}],"messages":[]}`,
userAgent: "claude-cli/2.1.22",
unchanged: true,
},
{
name: "no system field",
body: `{"messages":[]}`,
userAgent: "claude-cli/2.1.22",
unchanged: true,
},
{
name: "user-agent without version",
body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`,
userAgent: "Mozilla/5.0",
unchanged: true,
},
{
name: "empty user-agent",
body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`,
userAgent: "",
unchanged: true,
},
{
name: "version already matches",
body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.22; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`,
userAgent: "claude-cli/2.1.22",
unchanged: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := syncBillingHeaderVersion([]byte(tt.body), tt.userAgent)
if tt.unchanged {
assert.Equal(t, tt.body, string(result), "body should remain unchanged")
} else {
assert.Contains(t, string(result), tt.wantSub)
// Ensure old semver is gone
assert.NotContains(t, string(result), "cc_version=2.1.81")
}
})
}
}
func TestSignBillingHeaderCCH(t *testing.T) {
t.Run("replaces placeholder with hash", func(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63.a43; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
result := signBillingHeaderCCH(body)
// Should not have the placeholder anymore
assert.NotContains(t, string(result), "cch=00000")
// Should have a 5 hex-char cch value
billingText := gjson.GetBytes(result, "system.0.text").String()
require.Contains(t, billingText, "cch=")
assert.Regexp(t, `cch=[0-9a-f]{5};`, billingText)
})
t.Run("no placeholder - body unchanged", func(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=abcde;"}],"messages":[]}`)
result := signBillingHeaderCCH(body)
assert.Equal(t, string(body), string(result))
})
t.Run("no billing header - body unchanged", func(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"You are Claude Code."}],"messages":[]}`)
result := signBillingHeaderCCH(body)
assert.Equal(t, string(body), string(result))
})
t.Run("cch=00000 in user content is not touched", func(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"keep literal cch=00000 in this message"}]}]}`)
result := signBillingHeaderCCH(body)
// Billing header should be signed
billingText := gjson.GetBytes(result, "system.0.text").String()
assert.NotContains(t, billingText, "cch=00000")
// User message should keep its literal cch=00000
userText := gjson.GetBytes(result, "messages.0.content.0.text").String()
assert.Contains(t, userText, "cch=00000")
})
t.Run("signing is deterministic", func(t *testing.T) {
body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":"hi"}]}`)
r1 := signBillingHeaderCCH(body)
body2 := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":"hi"}]}`)
r2 := signBillingHeaderCCH(body2)
assert.Equal(t, string(r1), string(r2))
})
t.Run("matches reference algorithm", func(t *testing.T) {
// Verify: signBillingHeaderCCH(body) produces cch = xxHash64(body_with_placeholder, seed) & 0xFFFFF
body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63.a43; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
expectedCCH := fmt.Sprintf("%05x", xxHash64Seeded(body, cchSeed)&0xFFFFF)
result := signBillingHeaderCCH(body)
billingText := gjson.GetBytes(result, "system.0.text").String()
assert.Contains(t, billingText, "cch="+expectedCCH+";")
})
}
func TestXXHash64Seeded(t *testing.T) {
t.Run("matches cespare/xxhash for seed 0", func(t *testing.T) {
inputs := []string{"", "a", "hello world", "The quick brown fox jumps over the lazy dog"}
for _, s := range inputs {
data := []byte(s)
expected := xxhash.Sum64(data)
got := xxHash64Seeded(data, 0)
assert.Equal(t, expected, got, "mismatch for input %q", s)
}
})
t.Run("large input matches cespare", func(t *testing.T) {
data := make([]byte, 256)
for i := range data {
data[i] = byte(i)
}
expected := xxhash.Sum64(data)
got := xxHash64Seeded(data, 0)
assert.Equal(t, expected, got)
})
t.Run("deterministic with custom seed", func(t *testing.T) {
data := []byte("hello world")
h1 := xxHash64Seeded(data, cchSeed)
h2 := xxHash64Seeded(data, cchSeed)
assert.Equal(t, h1, h2)
})
t.Run("different seeds produce different results", func(t *testing.T) {
data := []byte("test data for hashing")
h1 := xxHash64Seeded(data, 0)
h2 := xxHash64Seeded(data, cchSeed)
assert.NotEqual(t, h1, h2)
})
}

View File

@ -284,7 +284,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
name string
body string
system any
wantSystemStr string // system 应为纯字符串
wantSystemText string // system array 第一个 block 的 text
wantMessagesLen int // messages 数组长度
wantFirstMsgRole string // 第一条消息的 role
wantFirstMsgText string // 第一条消息的 content[0].text
@ -294,21 +294,21 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
name: "nil system - no messages injected",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: nil,
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 1, // 原始 1 条消息,不注入
},
{
name: "empty string system - no messages injected",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: "",
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 1,
},
{
name: "custom string system - migrated to messages",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: "You are a personal assistant running inside OpenClaw.",
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 3, // instruction + ack + original
wantFirstMsgRole: "user",
wantFirstMsgText: "[System Instructions]\nYou are a personal assistant running inside OpenClaw.",
@ -318,7 +318,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
name: "system equals Claude Code prompt - no messages injected",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: claudeCodeSystemPrompt,
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 1,
},
{
@ -328,7 +328,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
map[string]any{"type": "text", "text": "First instruction"},
map[string]any{"type": "text", "text": "Second instruction"},
},
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 3,
wantFirstMsgRole: "user",
wantFirstMsgText: "[System Instructions]\nFirst instruction\n\nSecond instruction",
@ -338,14 +338,14 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
name: "empty array system - no messages injected",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: []any{},
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 1,
},
{
name: "json.RawMessage string system",
body: `{"model":"claude-3","system":"Custom prompt","messages":[{"role":"user","content":"hello"}]}`,
system: json.RawMessage(`"Custom prompt"`),
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 3,
wantFirstMsgRole: "user",
wantFirstMsgText: "[System Instructions]\nCustom prompt",
@ -355,14 +355,14 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
name: "json.RawMessage nil system",
body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`,
system: json.RawMessage(nil),
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 1,
},
{
name: "multiple original messages preserved",
body: `{"model":"claude-3","messages":[{"role":"user","content":"msg1"},{"role":"assistant","content":"resp1"},{"role":"user","content":"msg2"}]}`,
system: "Be helpful",
wantSystemStr: claudeCodeSystemPrompt,
wantSystemText: claudeCodeSystemPrompt,
wantMessagesLen: 5, // 2 injected + 3 original
wantFirstMsgRole: "user",
wantFirstMsgText: "[System Instructions]\nBe helpful",
@ -378,10 +378,17 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) {
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
// system 应为纯字符串
systemVal, ok := parsed["system"].(string)
require.True(t, ok, "system should be a string, got %T", parsed["system"])
require.Equal(t, tt.wantSystemStr, systemVal)
// system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}]
systemArr, ok := parsed["system"].([]any)
require.True(t, ok, "system should be an array, got %T", parsed["system"])
require.Len(t, systemArr, 1, "system array should have exactly 1 block")
systemBlock, ok := systemArr[0].(map[string]any)
require.True(t, ok)
require.Equal(t, "text", systemBlock["type"])
require.Equal(t, tt.wantSystemText, systemBlock["text"])
cc, ok := systemBlock["cache_control"].(map[string]any)
require.True(t, ok, "system block should have cache_control")
require.Equal(t, "ephemeral", cc["type"])
// 检查 messages
messages, ok := parsed["messages"].([]any)

View File

@ -3736,8 +3736,17 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
originalSystemText = strings.Join(parts, "\n\n")
}
// 2. 将 system 替换为 Claude Code 标准提示词(纯字符串,通过 Anthropic 检测)
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemPrompt)
// 2. 将 system 替换为 Claude Code 标准提示词array 格式,与真实 Claude Code 一致)
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
// 使用 string 格式会被 Anthropic 检测为第三方应用。
claudeCodeSystemBlock := []map[string]any{
{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
},
}
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
if !ok {
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
return body
@ -3975,17 +3984,22 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if shouldMimicClaudeCode {
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
systemRewritten := false
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = rewriteSystemForNonClaudeCode(body, parsed.System)
systemRewritten = true
}
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
// system 被重写时保留 CC prompt 的 cache_control: ephemeral匹配真实 Claude Code 行为);
// 未重写时haiku / 已含 CC 前缀)剥离客户端 cache_control与原有行为一致。
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
if s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil && fp != nil {
// metadata 透传开启时跳过 metadata 注入
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
if !mimicMPT {
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
normalizeOpts.injectMetadata = true
@ -5571,9 +5585,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth账号应用统一指纹和metadata重写受设置开关控制
var fingerprint *Fingerprint
enableFP, enableMPT := true, false
enableFP, enableMPT, enableCCH := true, false, false
if s.settingService != nil {
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
@ -5600,6 +5614,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if fingerprint != nil {
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
}
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
if enableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@ -5642,7 +5665,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
effectiveDropSet := mergeDropSets(policyFilterSet)
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
// 处理 anthropic-beta headerOAuth 账号需要包含 oauth beta
if tokenType == "oauth" {
@ -5653,11 +5675,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeCodeMimicHeaders(req, reqStream)
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
// Match real Claude CLI traffic (per mitmproxy reports):
// messages requests typically use only oauth + interleaved-thinking.
// Also drop claude-code beta if a downstream client added it.
// Claude Code OAuth credentials are scoped to Claude Code.
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
// this as a legitimate Claude Code request; without it, the request is
// rejected as third-party ("out of extra usage").
// Haiku models are exempt from third-party detection and don't need it.
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
if !strings.Contains(strings.ToLower(modelID), "haiku") {
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
}
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
} else {
// Claude Code 客户端:尽量透传原始 header仅补齐 oauth beta
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
@ -8501,9 +8528,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:应用统一指纹和重写 userID受设置开关控制
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
ctEnableFP, ctEnableMPT := true, false
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
if s.settingService != nil {
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
}
var ctFingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
@ -8521,6 +8548,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
if ctFingerprint != nil && ctEnableFP {
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
}
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err

View File

@ -612,7 +612,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
@ -685,7 +686,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
fullURL += "?alt=sse"
}
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq))
restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq)
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq))
if err != nil {
return nil, "", err
}
@ -3184,12 +3186,17 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
return nil
}
hasWebSearch := false
funcDecls := make([]any, 0, len(arr))
for _, t := range arr {
tm, ok := t.(map[string]any)
if !ok {
continue
}
if isClaudeWebSearchToolMap(tm) {
hasWebSearch = true
continue
}
var name, desc string
var params any
@ -3233,13 +3240,75 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
})
}
if len(funcDecls) == 0 {
out := make([]any, 0, 2)
if len(funcDecls) > 0 {
out = append(out, map[string]any{
"functionDeclarations": funcDecls,
})
}
if hasWebSearch {
out = append(out, map[string]any{
"googleSearch": map[string]any{},
})
}
if len(out) == 0 {
return nil
}
return []any{
map[string]any{
"functionDeclarations": funcDecls,
},
return out
}
func normalizeGeminiRequestForAIStudio(body []byte) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return body
}
tools, ok := payload["tools"].([]any)
if !ok || len(tools) == 0 {
return body
}
modified := false
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
if !ok {
continue
}
googleSearch, ok := tool["googleSearch"]
if !ok {
continue
}
if _, exists := tool["google_search"]; exists {
continue
}
tool["google_search"] = googleSearch
delete(tool, "googleSearch")
modified = true
}
if !modified {
return body
}
normalized, err := json.Marshal(payload)
if err != nil {
return body
}
return normalized
}
func isClaudeWebSearchToolMap(tool map[string]any) bool {
toolType, _ := tool["type"].(string)
if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" {
return true
}
name, _ := tool["name"].(string)
switch strings.TrimSpace(name) {
case "web_search", "google_search", "web_search_20250305":
return true
default:
return false
}
}

View File

@ -164,6 +164,35 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}
}
func TestConvertClaudeToolsToGeminiTools_PreservesWebSearchAlongsideFunctions(t *testing.T) {
tools := []any{
map[string]any{
"name": "get_weather",
"description": "Get weather info",
"input_schema": map[string]any{"type": "object"},
},
map[string]any{
"type": "web_search_20250305",
"name": "web_search",
},
}
result := convertClaudeToolsToGeminiTools(tools)
require.Len(t, result, 2)
functionDecl, ok := result[0].(map[string]any)
require.True(t, ok)
funcDecls, ok := functionDecl["functionDeclarations"].([]any)
require.True(t, ok)
require.Len(t, funcDecls, 1)
searchDecl, ok := result[1].(map[string]any)
require.True(t, ok)
googleSearch, ok := searchDecl["googleSearch"].(map[string]any)
require.True(t, ok)
require.Empty(t, googleSearch)
}
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
@ -232,6 +261,53 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst
require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:")
}
func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
httpStub := &geminiCompatHTTPUpstreamStub{
response: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"x-request-id": []string{"gemini-req-2"}},
Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)),
},
}
svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}}
account := &Account{
ID: 1,
Type: AccountTypeAPIKey,
Credentials: map[string]any{
"api_key": "test-key",
},
}
body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"get_weather","description":"Get weather info","input_schema":{"type":"object"}},{"type":"web_search_20250305","name":"web_search"}]}`)
result, err := svc.Forward(context.Background(), c, account, body)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, httpStub.lastReq)
postedBody, err := io.ReadAll(httpStub.lastReq.Body)
require.NoError(t, err)
var posted map[string]any
require.NoError(t, json.Unmarshal(postedBody, &posted))
tools, ok := posted["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, 2)
searchTool, ok := tools[1].(map[string]any)
require.True(t, ok)
_, hasSnake := searchTool["google_search"]
_, hasCamel := searchTool["googleSearch"]
require.True(t, hasSnake)
require.False(t, hasCamel)
_, hasFuncDecl := searchTool["functionDeclarations"]
require.False(t, hasFuncDecl)
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",

View File

@ -0,0 +1,107 @@
package service
import (
"encoding/json"
"strings"
"github.com/tidwall/gjson"
)
// contentSessionSeedPrefix prevents collisions between content-derived seeds
// and explicit session IDs (e.g. "sess-xxx" or "compat_cc_xxx").
const contentSessionSeedPrefix = "compat_cs_"
// deriveOpenAIContentSessionSeed builds a stable session seed from an
// OpenAI-format request body. Only fields constant across conversation turns
// are included: model, tools/functions definitions, system/developer prompts,
// instructions (Responses API), and the first user message.
// Supports both Chat Completions (messages) and Responses API (input).
func deriveOpenAIContentSessionSeed(body []byte) string {
if len(body) == 0 {
return ""
}
var b strings.Builder
if model := gjson.GetBytes(body, "model").String(); model != "" {
_, _ = b.WriteString("model=")
_, _ = b.WriteString(model)
}
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() && tools.Raw != "[]" {
_, _ = b.WriteString("|tools=")
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(tools.Raw)))
}
if funcs := gjson.GetBytes(body, "functions"); funcs.Exists() && funcs.IsArray() && funcs.Raw != "[]" {
_, _ = b.WriteString("|functions=")
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(funcs.Raw)))
}
if instr := gjson.GetBytes(body, "instructions").String(); instr != "" {
_, _ = b.WriteString("|instructions=")
_, _ = b.WriteString(instr)
}
firstUserCaptured := false
msgs := gjson.GetBytes(body, "messages")
if msgs.Exists() && msgs.IsArray() {
msgs.ForEach(func(_, msg gjson.Result) bool {
role := msg.Get("role").String()
switch role {
case "system", "developer":
_, _ = b.WriteString("|system=")
if c := msg.Get("content"); c.Exists() {
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw)))
}
case "user":
if !firstUserCaptured {
_, _ = b.WriteString("|first_user=")
if c := msg.Get("content"); c.Exists() {
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw)))
}
firstUserCaptured = true
}
}
return true
})
} else if inp := gjson.GetBytes(body, "input"); inp.Exists() {
if inp.Type == gjson.String {
_, _ = b.WriteString("|input=")
_, _ = b.WriteString(inp.String())
} else if inp.IsArray() {
inp.ForEach(func(_, item gjson.Result) bool {
role := item.Get("role").String()
switch role {
case "system", "developer":
_, _ = b.WriteString("|system=")
if c := item.Get("content"); c.Exists() {
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw)))
}
case "user":
if !firstUserCaptured {
_, _ = b.WriteString("|first_user=")
if c := item.Get("content"); c.Exists() {
_, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw)))
}
firstUserCaptured = true
}
}
if !firstUserCaptured && item.Get("type").String() == "input_text" {
_, _ = b.WriteString("|first_user=")
if text := item.Get("text").String(); text != "" {
_, _ = b.WriteString(text)
}
firstUserCaptured = true
}
return true
})
}
}
if b.Len() == 0 {
return ""
}
return contentSessionSeedPrefix + b.String()
}

View File

@ -0,0 +1,218 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDeriveOpenAIContentSessionSeed_EmptyInputs(t *testing.T) {
require.Empty(t, deriveOpenAIContentSessionSeed(nil))
require.Empty(t, deriveOpenAIContentSessionSeed([]byte{}))
require.Empty(t, deriveOpenAIContentSessionSeed([]byte(`{}`)))
}
func TestDeriveOpenAIContentSessionSeed_ModelOnly(t *testing.T) {
seed := deriveOpenAIContentSessionSeed([]byte(`{"model":"gpt-5.4"}`))
require.Contains(t, seed, contentSessionSeedPrefix)
require.Contains(t, seed, "model=gpt-5.4")
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_StableAcrossTurns(t *testing.T) {
turn1 := []byte(`{
"model": "gpt-5.4",
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"}
]
}`)
turn2 := []byte(`{
"model": "gpt-5.4",
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
}`)
s1 := deriveOpenAIContentSessionSeed(turn1)
s2 := deriveOpenAIContentSessionSeed(turn2)
require.Equal(t, s1, s2, "seed should be stable across later turns")
require.NotEmpty(t, s1)
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentFirstUserDiffers(t *testing.T) {
req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Question A"}]}`)
req2 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Question B"}]}`)
s1 := deriveOpenAIContentSessionSeed(req1)
s2 := deriveOpenAIContentSessionSeed(req2)
require.NotEqual(t, s1, s2)
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentSystemDiffers(t *testing.T) {
req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"A"},{"role":"user","content":"Hi"}]}`)
req2 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"B"},{"role":"user","content":"Hi"}]}`)
s1 := deriveOpenAIContentSessionSeed(req1)
s2 := deriveOpenAIContentSessionSeed(req2)
require.NotEqual(t, s1, s2)
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentModelDiffers(t *testing.T) {
req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hi"}]}`)
req2 := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hi"}]}`)
s1 := deriveOpenAIContentSessionSeed(req1)
s2 := deriveOpenAIContentSessionSeed(req2)
require.NotEqual(t, s1, s2)
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_WithTools(t *testing.T) {
withTools := []byte(`{
"model": "gpt-5.4",
"tools": [{"type":"function","function":{"name":"get_weather"}}],
"messages": [{"role": "user", "content": "Hello"}]
}`)
withoutTools := []byte(`{
"model": "gpt-5.4",
"messages": [{"role": "user", "content": "Hello"}]
}`)
s1 := deriveOpenAIContentSessionSeed(withTools)
s2 := deriveOpenAIContentSessionSeed(withoutTools)
require.NotEqual(t, s1, s2, "tools should affect the seed")
require.Contains(t, s1, "|tools=")
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_WithFunctions(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"functions": [{"name":"get_weather","parameters":{}}],
"messages": [{"role": "user", "content": "Hello"}]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|functions=")
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DeveloperRole(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"messages": [
{"role": "developer", "content": "You are helpful."},
{"role": "user", "content": "Hello"}
]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|system=")
require.Contains(t, seed, "|first_user=")
}
func TestDeriveOpenAIContentSessionSeed_ChatCompletions_StructuredContent(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"messages": [
{"role": "user", "content": [{"type":"text","text":"Hello"}]}
]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.NotEmpty(t, seed)
require.Contains(t, seed, "|first_user=")
}
func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputString(t *testing.T) {
body := []byte(`{"model":"gpt-5.4","input":"Hello, how are you?"}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|input=Hello, how are you?")
}
func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputArray(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"input": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"}
]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|system=")
require.Contains(t, seed, "|first_user=")
}
func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_WithInstructions(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"instructions": "You are a coding assistant.",
"input": "Write a hello world"
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|instructions=You are a coding assistant.")
require.Contains(t, seed, "|input=Write a hello world")
}
func TestDeriveOpenAIContentSessionSeed_Deterministic(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello"}
]
}`)
s1 := deriveOpenAIContentSessionSeed(body)
s2 := deriveOpenAIContentSessionSeed(body)
require.Equal(t, s1, s2, "seed must be deterministic")
}
func TestDeriveOpenAIContentSessionSeed_PrefixPresent(t *testing.T) {
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hi"}]}`)
seed := deriveOpenAIContentSessionSeed(body)
require.True(t, len(seed) > len(contentSessionSeedPrefix))
require.Equal(t, contentSessionSeedPrefix, seed[:len(contentSessionSeedPrefix)])
}
func TestDeriveOpenAIContentSessionSeed_EmptyToolsIgnored(t *testing.T) {
body := []byte(`{"model":"gpt-5.4","tools":[],"messages":[{"role":"user","content":"Hi"}]}`)
seed := deriveOpenAIContentSessionSeed(body)
require.NotContains(t, seed, "|tools=")
}
func TestDeriveOpenAIContentSessionSeed_MessagesPreferredOverInput(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"messages": [{"role": "user", "content": "from messages"}],
"input": "from input"
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|first_user=")
require.NotContains(t, seed, "|input=")
}
func TestDeriveOpenAIContentSessionSeed_JSONCanonicalisation(t *testing.T) {
compact := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather"}}],"messages":[{"role":"user","content":"Hi"}]}`)
spaced := []byte(`{
"model": "gpt-5.4",
"tools": [
{ "type" : "function", "function": { "description": "Get weather", "name": "get_weather" } }
],
"messages": [ { "role": "user", "content": "Hi" } ]
}`)
s1 := deriveOpenAIContentSessionSeed(compact)
s2 := deriveOpenAIContentSessionSeed(spaced)
require.Equal(t, s1, s2, "different formatting of identical JSON should produce the same seed")
}
func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputTextTypedItem(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"input": [{"type": "input_text", "text": "Hello world"}]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|first_user=")
require.Contains(t, seed, "Hello world")
}
func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_TypedMessageItem(t *testing.T) {
body := []byte(`{
"model": "gpt-5.4",
"input": [{"type": "message", "role": "user", "content": "Hello from typed message"}]
}`)
seed := deriveOpenAIContentSessionSeed(body)
require.Contains(t, seed, "|first_user=")
require.Contains(t, seed, "Hello from typed message")
}

View File

@ -1121,6 +1121,7 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str
// 1. Header: session_id
// 2. Header: conversation_id
// 3. Body: prompt_cache_key (opencode)
// 4. Body: content-based fallback (model + system + tools + first user message)
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string {
if c == nil {
return ""
@ -1133,6 +1134,9 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte)
if sessionID == "" && len(body) > 0 {
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
}
if sessionID == "" && len(body) > 0 {
sessionID = deriveOpenAIContentSessionSeed(body)
}
if sessionID == "" {
return ""
}
@ -2048,6 +2052,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
if sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) {
bodyModified = true
disablePatch()
}
// Re-serialize body only if modified
if bodyModified {
serializedByPatch := false
@ -2475,6 +2484,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
reqStream = gjson.GetBytes(body, "stream").Bool()
}
sanitizedBody, sanitized, err := sanitizeEmptyBase64InputImagesInOpenAIBody(body)
if err != nil {
return nil, err
}
if sanitized {
body = sanitizedBody
}
logger.LegacyPrintf("service.openai_gateway",
"[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
@ -3007,6 +3024,14 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
return nil, err
}
// Detect SSE responses from upstream and convert to JSON.
// Some upstreams (e.g. other sub2api instances) may return SSE even when
// stream=false was requested. Without this conversion the client would
// receive raw SSE text or a terminal event with empty output.
if isEventStreamResponse(resp.Header) {
return s.handlePassthroughSSEToJSON(resp, c, body)
}
usage := &OpenAIUsage{}
usageParsed := false
if len(body) > 0 {
@ -3030,6 +3055,56 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
return usage, nil
}
// handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON but skips
// model replacement (passthrough does not remap models).
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
usage := &OpenAIUsage{}
if ok {
if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed {
*usage = parsedUsage
}
// When the terminal event has an empty output array, reconstruct
// output from accumulated delta events so the client gets full content.
if len(gjson.GetBytes(finalResponse, "output").Array()) == 0 {
if outputJSON, reconstructed := reconstructResponseOutputFromSSE(bodyText); reconstructed {
if patched, err := sjson.SetRawBytes(finalResponse, "output", outputJSON); err == nil {
finalResponse = patched
}
}
}
body = finalResponse
// Correct tool calls in final response
body = s.correctToolCallsInResponseBody(body)
} else {
terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
if terminalOK && terminalType == "response.failed" {
msg := extractOpenAISSEErrorMessage(terminalPayload)
if msg == "" {
msg = "Upstream compact response failed"
}
return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
}
usage = s.parseSSEUsageFromBody(bodyText)
}
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
contentType := "application/json; charset=utf-8"
if !ok {
contentType = resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "text/event-stream"
}
}
c.Data(resp.StatusCode, contentType, body)
return usage, nil
}
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
if dst == nil || src == nil {
return
@ -3858,10 +3933,21 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r
return nil, err
}
// Detect SSE responses for ALL account types via Content-Type header.
// Some OpenAI-compatible upstreams (including other sub2api instances)
// may return SSE even when stream=false was requested.
if isEventStreamResponse(resp.Header) {
return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel)
}
// For OAuth accounts, also fall back to a body-content heuristic because
// the upstream may omit the Content-Type header while still sending SSE.
// This heuristic is NOT applied to API-key accounts to avoid false
// positives on JSON responses that coincidentally contain "data:" or
// "event:" in their text content.
if account.Type == AccountTypeOAuth {
bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:"))
if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE {
return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel)
if bodyLooksLikeSSE {
return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel)
}
}
@ -3895,7 +3981,7 @@ func isEventStreamResponse(header http.Header) bool {
return strings.Contains(contentType, "text/event-stream")
}
func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) {
bodyText := string(body)
finalResponse, ok := extractCodexFinalResponse(bodyText)
@ -4954,6 +5040,123 @@ func normalizeOpenAIServiceTier(raw string) *string {
}
}
func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) {
if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) {
return body, false, nil
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return body, false, fmt.Errorf("sanitize request body: %w", err)
}
if !sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) {
return body, false, nil
}
normalized, err := json.Marshal(reqBody)
if err != nil {
return body, false, fmt.Errorf("serialize sanitized request body: %w", err)
}
return normalized, true, nil
}
func sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody map[string]any) bool {
if reqBody == nil {
return false
}
input, ok := reqBody["input"]
if !ok {
return false
}
normalizedInput, changed := sanitizeEmptyBase64InputImagesInOpenAIInput(input)
if !changed {
return false
}
reqBody["input"] = normalizedInput
return true
}
func sanitizeEmptyBase64InputImagesInOpenAIInput(input any) (any, bool) {
items, ok := input.([]any)
if !ok {
return input, false
}
normalizedItems := make([]any, 0, len(items))
changed := false
for _, item := range items {
itemMap, ok := item.(map[string]any)
if !ok {
normalizedItems = append(normalizedItems, item)
continue
}
if shouldDropEmptyBase64InputImagePart(itemMap) {
changed = true
continue
}
content, ok := itemMap["content"]
if !ok {
normalizedItems = append(normalizedItems, itemMap)
continue
}
parts, ok := content.([]any)
if !ok {
normalizedItems = append(normalizedItems, itemMap)
continue
}
normalizedParts := make([]any, 0, len(parts))
itemChanged := false
for _, part := range parts {
if shouldDropEmptyBase64InputImagePart(part) {
changed = true
itemChanged = true
continue
}
normalizedParts = append(normalizedParts, part)
}
if itemChanged {
if len(normalizedParts) == 0 {
continue
}
itemMap["content"] = normalizedParts
}
normalizedItems = append(normalizedItems, itemMap)
}
if !changed {
return input, false
}
return normalizedItems, true
}
func shouldDropEmptyBase64InputImagePart(part any) bool {
partMap, ok := part.(map[string]any)
if !ok {
return false
}
typeValue, _ := partMap["type"].(string)
if strings.TrimSpace(typeValue) != "input_image" {
return false
}
imageURL, _ := partMap["image_url"].(string)
return isEmptyBase64DataURI(imageURL)
}
func isEmptyBase64DataURI(raw string) bool {
if !strings.HasPrefix(raw, "data:") {
return false
}
rest := strings.TrimPrefix(raw, "data:")
semicolonIdx := strings.Index(rest, ";")
if semicolonIdx < 0 {
return false
}
rest = rest[semicolonIdx+1:]
if !strings.HasPrefix(rest, "base64,") {
return false
}
return strings.TrimSpace(strings.TrimPrefix(rest, "base64,")) == ""
}
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
if c != nil {
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {

View File

@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"net/http/httptest"
"testing"
@ -139,3 +140,61 @@ func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
require.True(t, ok)
require.Equal(t, got, cachedMap)
}
func TestSanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(t *testing.T) {
var reqBody map[string]any
require.NoError(t, json.Unmarshal([]byte(`{
"model":"gpt-5.4",
"input":[
{"role":"user","content":[
{"type":"input_text","text":"Describe this"},
{"type":"input_image","image_url":"data:image/png;base64, "},
{"type":"input_image","image_url":"data:image/png;base64,abc123"}
]},
{"role":"user","content":[
{"type":"input_image","image_url":"data:image/png;base64,"}
]},
{"type":"input_image","image_url":"data:image/png;base64,"},
{"type":"input_image","image_url":"data:image/png;base64,top-level-valid"}
]
}`), &reqBody))
require.True(t, sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody))
normalized, err := json.Marshal(reqBody)
require.NoError(t, err)
require.JSONEq(t, `{
"model":"gpt-5.4",
"input":[
{"role":"user","content":[
{"type":"input_text","text":"Describe this"},
{"type":"input_image","image_url":"data:image/png;base64,abc123"}
]},
{"type":"input_image","image_url":"data:image/png;base64,top-level-valid"}
]
}`, string(normalized))
}
func TestSanitizeEmptyBase64InputImagesInOpenAIBody(t *testing.T) {
body, changed, err := sanitizeEmptyBase64InputImagesInOpenAIBody([]byte(`{
"model":"gpt-5.4",
"stream":true,
"input":[
{"role":"user","content":[
{"type":"input_text","text":"Describe this"},
{"type":"input_image","image_url":"data:image/png;base64,"}
]}
]
}`))
require.NoError(t, err)
require.True(t, changed)
require.JSONEq(t, `{
"model":"gpt-5.4",
"stream":true,
"input":[
{"role":"user","content":[
{"type":"input_text","text":"Describe this"}
]}
]
}`, string(body))
}

View File

@ -237,6 +237,60 @@ func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) {
require.Equal(t, "", empty)
}
func TestOpenAIGatewayService_GenerateSessionHash_ContentFallback(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"}]}`)
hash := svc.GenerateSessionHash(c, body)
require.NotEmpty(t, hash, "content-based fallback should produce a hash")
hash2 := svc.GenerateSessionHash(c, body)
require.Equal(t, hash, hash2, "same content should produce same hash")
bodyExtended := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi!"},{"role":"user","content":"How are you?"}]}`)
hashExtended := svc.GenerateSessionHash(c, bodyExtended)
require.Equal(t, hash, hashExtended, "hash should be stable across later turns")
bodyDifferent := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Different question"}]}`)
hashDifferent := svc.GenerateSessionHash(c, bodyDifferent)
require.NotEqual(t, hash, hashDifferent, "different content should produce different hash")
}
func TestOpenAIGatewayService_GenerateSessionHash_ExplicitSignalWinsOverContent(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil)
svc := &OpenAIGatewayService{}
body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hello"}]}`)
contentHash := svc.GenerateSessionHash(c, body)
require.NotEmpty(t, contentHash)
c.Request.Header.Set("session_id", "explicit-session")
explicitHash := svc.GenerateSessionHash(c, body)
require.NotEmpty(t, explicitHash)
require.NotEqual(t, contentHash, explicitHash, "explicit session_id should override content fallback")
}
func TestOpenAIGatewayService_GenerateSessionHash_EmptyBodyStillEmpty(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil)
svc := &OpenAIGatewayService{}
require.Empty(t, svc.GenerateSessionHash(c, []byte(`{}`)))
require.Empty(t, svc.GenerateSessionHash(c, nil))
}
func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if c.waitCounts != nil {
if count, ok := c.waitCounts[accountID]; ok {
@ -1797,7 +1851,7 @@ func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
require.Contains(t, string(finalResp), `"input_tokens":11`)
}
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@ -1814,7 +1868,7 @@ func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 7, usage.InputTokens)
@ -1826,7 +1880,7 @@ func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
require.NotContains(t, rec.Body.String(), "data:")
}
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@ -1842,7 +1896,7 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 0, usage.InputTokens)
@ -1850,7 +1904,7 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
}
func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
func TestHandleSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
@ -1866,7 +1920,7 @@ func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) {
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.Nil(t, usage)
require.Error(t, err)
require.Equal(t, http.StatusBadGateway, rec.Code)

View File

@ -81,6 +81,7 @@ const backendModeDBTimeout = 5 * time.Second
type cachedGatewayForwardingSettings struct {
fingerprintUnification bool
metadataPassthrough bool
cchSigning bool
expiresAt int64 // unix nano
}
@ -514,6 +515,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
// Gateway forwarding behavior
updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification)
updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough)
updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning)
err = s.settingRepo.SetMultiple(ctx, updates)
if err == nil {
@ -533,6 +535,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: settings.EnableFingerprintUnification,
metadataPassthrough: settings.EnableMetadataPassthrough,
cchSigning: settings.EnableCCHSigning,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
if s.onUpdate != nil {
@ -639,20 +642,20 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool {
// GetGatewayForwardingSettings returns cached gateway forwarding settings.
// Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path.
// Returns (fingerprintUnification, metadataPassthrough).
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough bool) {
// Returns (fingerprintUnification, metadataPassthrough, cchSigning).
func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return cached.fingerprintUnification, cached.metadataPassthrough
return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning
}
}
type gwfResult struct {
fp, mp bool
fp, mp, cch bool
}
val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) {
if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil {
if time.Now().UnixNano() < cached.expiresAt {
return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough}, nil
return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil
}
}
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout)
@ -660,32 +663,36 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing
values, err := s.settingRepo.GetMultiple(dbCtx, []string{
SettingKeyEnableFingerprintUnification,
SettingKeyEnableMetadataPassthrough,
SettingKeyEnableCCHSigning,
})
if err != nil {
slog.Warn("failed to get gateway forwarding settings", "error", err)
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: true,
metadataPassthrough: false,
cchSigning: false,
expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(),
})
return gwfResult{true, false}, nil
return gwfResult{true, false, false}, nil
}
fp := true
if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" {
fp = v == "true"
}
mp := values[SettingKeyEnableMetadataPassthrough] == "true"
cch := values[SettingKeyEnableCCHSigning] == "true"
gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{
fingerprintUnification: fp,
metadataPassthrough: mp,
cchSigning: cch,
expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(),
})
return gwfResult{fp, mp}, nil
return gwfResult{fp, mp, cch}, nil
})
if r, ok := val.(gwfResult); ok {
return r.fp, r.mp
return r.fp, r.mp, r.cch
}
return true, false // fail-open defaults
return true, false, false // fail-open defaults
}
// IsEmailVerifyEnabled 检查是否开启邮件验证
@ -983,13 +990,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
// 分组隔离
result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true"
// Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false)
// Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false, cch_signing=false)
if v, ok := settings[SettingKeyEnableFingerprintUnification]; ok && v != "" {
result.EnableFingerprintUnification = v == "true"
} else {
result.EnableFingerprintUnification = true // default: enabled (current behavior)
}
result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
return result
}

View File

@ -78,6 +78,7 @@ type SystemSettings struct {
// Gateway forwarding behavior
EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata默认 false
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false
}
type DefaultSubscriptionSetting struct {

View File

@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
ARG GOLANG_IMAGE=golang:1.26.1-alpine
ARG GOLANG_IMAGE=golang:1.26.2-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn

View File

@ -89,6 +89,7 @@ export interface SystemSettings {
// Gateway forwarding behavior
enable_fingerprint_unification: boolean
enable_metadata_passthrough: boolean
enable_cch_signing: boolean
}
export interface UpdateSettingsRequest {
@ -146,6 +147,7 @@ export interface UpdateSettingsRequest {
allow_ungrouped_key_scheduling?: boolean
enable_fingerprint_unification?: boolean
enable_metadata_passthrough?: boolean
enable_cch_signing?: boolean
}
/**

View File

@ -4268,6 +4268,8 @@ export default {
fingerprintUnificationHint: 'Unify X-Stainless-* headers across users sharing the same OAuth account. Disabling passes through each client\'s original headers.',
metadataPassthrough: 'Metadata Passthrough',
metadataPassthroughHint: 'Pass through client\'s original metadata.user_id without rewriting. May improve upstream cache hit rates.',
cchSigning: 'CCH Signing',
cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.',
},
site: {
title: 'Site Settings',

View File

@ -4431,6 +4431,8 @@ export default {
fingerprintUnificationHint: '统一共享同一 OAuth 账号的用户的 X-Stainless-* 请求头。关闭后透传客户端原始请求头。',
metadataPassthrough: 'Metadata 透传',
metadataPassthroughHint: '透传客户端原始 metadata.user_id不进行重写。可能提高上游缓存命中率。',
cchSigning: 'CCH 签名',
cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。',
},
site: {
title: '站点设置',

View File

@ -1376,6 +1376,19 @@
</div>
<Toggle v-model="form.enable_metadata_passthrough" />
</div>
<!-- CCH Signing -->
<div class="flex items-center justify-between">
<div>
<label class="text-sm font-medium text-gray-700 dark:text-gray-300">
{{ t('admin.settings.gatewayForwarding.cchSigning') }}
</label>
<p class="mt-0.5 text-xs text-gray-500 dark:text-gray-400">
{{ t('admin.settings.gatewayForwarding.cchSigningHint') }}
</p>
</div>
<Toggle v-model="form.enable_cch_signing" />
</div>
</div>
</div>
</div><!-- /Tab: Gateway Claude Code, Scheduling -->
@ -2248,7 +2261,8 @@ const form = reactive<SettingsForm>({
allow_ungrouped_key_scheduling: false,
// Gateway forwarding behavior
enable_fingerprint_unification: true,
enable_metadata_passthrough: false
enable_metadata_passthrough: false,
enable_cch_signing: false
})
const defaultSubscriptionGroupOptions = computed<DefaultSubscriptionGroupOption[]>(() =>
@ -2556,7 +2570,8 @@ async function saveSettings() {
max_claude_code_version: form.max_claude_code_version,
allow_ungrouped_key_scheduling: form.allow_ungrouped_key_scheduling,
enable_fingerprint_unification: form.enable_fingerprint_unification,
enable_metadata_passthrough: form.enable_metadata_passthrough
enable_metadata_passthrough: form.enable_metadata_passthrough,
enable_cch_signing: form.enable_cch_signing
}
const updated = await adminAPI.settings.updateSettings(payload)
Object.assign(form, updated)