diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 01c00bb9..6f76ef4f 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -19,7 +19,7 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.26.1' + go version | grep -q 'go1.26.2' - name: Unit tests working-directory: backend run: make test-unit @@ -38,7 +38,7 @@ jobs: cache: true - name: Verify Go version run: | - go version | grep -q 'go1.26.1' + go version | grep -q 'go1.26.2' - name: golangci-lint uses: golangci/golangci-lint-action@v9 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c51b3c07..b729c575 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,7 +115,7 @@ jobs: - name: Verify Go version run: | - go version | grep -q 'go1.26.1' + go version | grep -q 'go1.26.2' # Docker setup for GoReleaser - name: Set up QEMU diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml index cc5a90cf..600fd2fa 100644 --- a/.github/workflows/security-scan.yml +++ b/.github/workflows/security-scan.yml @@ -23,7 +23,7 @@ jobs: cache-dependency-path: backend/go.sum - name: Verify Go version run: | - go version | grep -q 'go1.26.1' + go version | grep -q 'go1.26.2' - name: Run govulncheck working-directory: backend run: | diff --git a/README.md b/README.md index 50611a6d..2f73e92a 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,11 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot CTok Thanks to CTok.ai for sponsoring this project! CTok.ai is dedicated to building a one-stop AI programming tool service platform. We offer professional Claude Code packages and technical community services, with support for Google Gemini and OpenAI Codex. Through carefully designed plans and a professional tech community, we provide developers with reliable service guarantees and continuous technical support, making AI-assisted programming a true productivity tool. Click here to register! + + +silkapi +Thanks to SilkAPI for sponsoring this project! SilkAPI is a relay service built on Sub2API, specializing in providing high-speed and stable Codex API relay. + ## Ecosystem diff --git a/README_CN.md b/README_CN.md index 797f106b..a0c3fd4b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -69,6 +69,10 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅的 感谢 CTok.ai 赞助了本项目!CTok.ai 致力于打造一站式 AI 编程工具服务平台。我们提供 Claude Code 专业套餐及技术社群服务,同时支持 Google Gemini 和 OpenAI Codex。通过精心设计的套餐方案和专业的技术社群,为开发者提供稳定的服务保障和持续的技术支持,让 AI 辅助编程真正成为开发者的生产力工具。点击这里注册! + +silkapi +感谢 丝绸API 赞助了本项目! 丝绸API 是基于 Sub2API 搭建的中转服务,专注于提供 Codex 高速稳定API中转。 + ## 生态项目 diff --git a/README_JA.md b/README_JA.md index b7820554..bd69e06b 100644 --- a/README_JA.md +++ b/README_JA.md @@ -68,6 +68,11 @@ Sub2API は、AI 製品のサブスクリプションから API クォータを CTok CTok.ai のご支援に感謝します!CTok.ai はワンストップ AI プログラミングツールサービスプラットフォームの構築に取り組んでいます。Claude Code の専用プランと技術コミュニティサービスを提供し、Google Gemini や OpenAI Codex もサポートしています。丁寧に設計されたプランと専門的な技術コミュニティを通じて、開発者に安定したサービス保証と継続的な技術サポートを提供し、AI アシスト プログラミングを真の生産性向上ツールにします。こちらから登録! + + +silkapi +SilkAPI のご支援に感謝します!SilkAPI は Sub2API をベースに構築された中継サービスで、高速かつ安定した Codex API 中継の提供に特化しています。 + ## エコシステム diff --git a/assets/partners/logos/silkapi.png b/assets/partners/logos/silkapi.png new file mode 100644 index 00000000..97afbda9 Binary files /dev/null and b/assets/partners/logos/silkapi.png differ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 94e74b8e..b6e5c2ad 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.109 +0.1.110 diff --git a/backend/go.mod b/backend/go.mod index 135cbd3e..c4fc52f1 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,12 +1,12 @@ module github.com/Wei-Shaw/sub2api -go 1.26.1 +go 1.26.2 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 + github.com/andybalholm/brotli v1.2.0 github.com/aws/aws-sdk-go-v2 v1.41.3 github.com/aws/aws-sdk-go-v2/config v1.32.10 github.com/aws/aws-sdk-go-v2/credentials v1.19.10 @@ -50,7 +50,6 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect - github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect @@ -67,14 +66,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect github.com/aws/smithy-go v1.24.2 // indirect - github.com/bdandy/go-errors v1.2.2 // indirect - github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect - github.com/bogdanfinn/fhttp v0.6.8 // indirect - github.com/bogdanfinn/quic-go-utls v1.0.9-utls // indirect - github.com/bogdanfinn/tls-client v1.14.0 // indirect - github.com/bogdanfinn/utls v1.7.7-barnius // indirect - github.com/bogdanfinn/websocket v1.5.5-barnius // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -151,7 +143,6 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect - github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index f5b7968f..996a4b6d 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -10,8 +10,6 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/DouDOU-start/go-sora2api v1.1.0 h1:PxWiukK77StiHxEngOFwT1rKUn9oTAJJTl07wQUXwiU= -github.com/DouDOU-start/go-sora2api v1.1.0/go.mod h1:dcwpethoKfAsMWskDD9iGgc/3yox2tkthPLSMVGnhkE= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= @@ -60,24 +58,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb8 github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= -github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= -github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= -github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= -github.com/bdandy/go-socks4 v1.2.3/go.mod h1:98kiVFgpdogR8aIGLWLvjDVZ8XcKPsSI/ypGrO+bqHI= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= -github.com/bogdanfinn/fhttp v0.6.8 h1:LiQyHOY3i0QoxxNB7nq27/nGNNbtPj0fuBPozhR7Ws4= -github.com/bogdanfinn/fhttp v0.6.8/go.mod h1:A+EKDzMx2hb4IUbMx4TlkoHnaJEiLl8r/1Ss1Y+5e5M= -github.com/bogdanfinn/quic-go-utls v1.0.9-utls h1:tV6eDEiRbRCcepALSzxR94JUVD3N3ACIiRLgyc2Ep8s= -github.com/bogdanfinn/quic-go-utls v1.0.9-utls/go.mod h1:aHph9B9H9yPOt5xnhWKSOum27DJAqpiHzwX+gjvaXcg= -github.com/bogdanfinn/tls-client v1.14.0 h1:vyk7Cn4BIvLAGVuMfb0tP22OqogfO1lYamquQNEZU1A= -github.com/bogdanfinn/tls-client v1.14.0/go.mod h1:LsU6mXVn8MOFDwTkyRfI7V1BZM1p0wf2ZfZsICW/1fM= -github.com/bogdanfinn/utls v1.7.7-barnius h1:OuJ497cc7F3yKNVHRsYPQdGggmk5x6+V5ZlrCR7fOLU= -github.com/bogdanfinn/utls v1.7.7-barnius/go.mod h1:aAK1VZQlpKZClF1WEQeq6kyclbkPq4hz6xTbB5xSlmg= -github.com/bogdanfinn/websocket v1.5.5-barnius h1:bY+qnxpai1qe7Jmjx+Sds/cmOSpuuLoR8x61rWltjOI= -github.com/bogdanfinn/websocket v1.5.5-barnius/go.mod h1:gvvEw6pTKHb7yOiFvIfAFTStQWyrm25BMVCTj5wRSsI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -94,10 +78,6 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= -github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= -github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= -github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -199,8 +179,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -236,8 +214,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= -github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -271,8 +247,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -324,8 +298,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -347,8 +319,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= -github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5 h1:YqAladjX7xpA6BM04leXMWAEjS0mTZ5kUU9KRBriQJc= -github.com/tam7t/hpkp v0.0.0-20160821193359-2b70b4024ed5/go.mod h1:2JjD2zLQYH5HO74y5+aE3remJQvl6q4Sn6aWA2wD1Ng= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= @@ -421,15 +391,12 @@ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2 golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= -golang.org/x/net v0.0.0-20211104170005-ce137452f963/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -439,15 +406,12 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 563a27ce..b503e5c3 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,8 +1,6 @@ package admin import ( - "errors" - "fmt" "strconv" "strings" @@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe return result } -// validatePricingBillingMode 校验计费配置 -func validatePricingBillingMode(pricing []service.ChannelModelPricing) error { - for _, p := range pricing { - // 按次/图片模式必须配置默认价格或区间 - if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - return errors.New("per-request price or intervals required for per_request/image billing mode") - } - } - // 校验价格不能为负 - if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil { - return err - } - if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil { - return err - } - if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil { - return err - } - // 校验 interval:至少有一个价格字段非空 - for _, iv := range p.Intervals { - if iv.InputPrice == nil && iv.OutputPrice == nil && - iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && - iv.PerRequestPrice == nil { - return fmt.Errorf("interval [%d, %s] has no price fields set for model %v", - iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models) - } - } - } - return nil -} - -func validatePriceNotNegative(field string, val *float64) error { - if val != nil && *val < 0 { - return fmt.Errorf("%s must be >= 0", field) - } - return nil -} - -func formatMaxTokens(max *int) string { - if max == nil { - return "∞" - } - return fmt.Sprintf("%d", *max) -} - // --- Handlers --- // List handles listing channels with pagination @@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) - if err := validatePricingBillingMode(pricing); err != nil { - response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) - return - } channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ Name: req.Name, @@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) - if err := validatePricingBillingMode(pricing); err != nil { - response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) - return - } input.ModelPricing = &pricing } diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 6f6ea526..2f4b4440 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) { require.Nil(t, r.ImageOutputPrice) require.Nil(t, r.PerRequestPrice) } - -// --------------------------------------------------------------------------- -// 3. validatePricingBillingMode -// --------------------------------------------------------------------------- - -func TestValidatePricingBillingMode(t *testing.T) { - tests := []struct { - name string - pricing []service.ChannelModelPricing - wantErr bool - }{ - { - name: "token mode - valid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModeToken}, - }, - wantErr: false, - }, - { - name: "per_request with price - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModePerRequest, - PerRequestPrice: float64Ptr(0.5), - }, - }, - wantErr: false, - }, - { - name: "per_request with intervals - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModePerRequest, - Intervals: []service.PricingInterval{ - {MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)}, - }, - }, - }, - wantErr: false, - }, - { - name: "per_request no price no intervals - invalid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModePerRequest}, - }, - wantErr: true, - }, - { - name: "image with price - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModeImage, - PerRequestPrice: float64Ptr(0.2), - }, - }, - wantErr: false, - }, - { - name: "image no price no intervals - invalid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModeImage}, - }, - wantErr: true, - }, - { - name: "empty list - valid", - pricing: []service.ChannelModelPricing{}, - wantErr: false, - }, - { - name: "mixed modes with invalid image - invalid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModeToken, - InputPrice: float64Ptr(0.01), - }, - { - BillingMode: service.BillingModePerRequest, - PerRequestPrice: float64Ptr(0.5), - }, - { - BillingMode: service.BillingModeImage, - }, - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validatePricingBillingMode(tt.pricing) - if tt.wantErr { - require.Error(t, err) - require.Contains(t, err.Error(), "per-request price or intervals required") - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 06916917..4cbe5188 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -128,6 +128,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { BackendModeEnabled: settings.BackendModeEnabled, EnableFingerprintUnification: settings.EnableFingerprintUnification, EnableMetadataPassthrough: settings.EnableMetadataPassthrough, + EnableCCHSigning: settings.EnableCCHSigning, }) } @@ -211,6 +212,7 @@ type UpdateSettingsRequest struct { // Gateway forwarding behavior EnableFingerprintUnification *bool `json:"enable_fingerprint_unification"` EnableMetadataPassthrough *bool `json:"enable_metadata_passthrough"` + EnableCCHSigning *bool `json:"enable_cch_signing"` } // UpdateSettings 更新系统设置 @@ -614,6 +616,12 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } return previousSettings.EnableMetadataPassthrough }(), + EnableCCHSigning: func() bool { + if req.EnableCCHSigning != nil { + return *req.EnableCCHSigning + } + return previousSettings.EnableCCHSigning + }(), } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -693,6 +701,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { BackendModeEnabled: updatedSettings.BackendModeEnabled, EnableFingerprintUnification: updatedSettings.EnableFingerprintUnification, EnableMetadataPassthrough: updatedSettings.EnableMetadataPassthrough, + EnableCCHSigning: updatedSettings.EnableCCHSigning, }) } @@ -871,6 +880,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.EnableMetadataPassthrough != after.EnableMetadataPassthrough { changed = append(changed, "enable_metadata_passthrough") } + if before.EnableCCHSigning != after.EnableCCHSigning { + changed = append(changed, "enable_cch_signing") + } return changed } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index aecbf0c8..73707f79 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -97,6 +97,7 @@ type SystemSettings struct { // Gateway forwarding behavior EnableFingerprintUnification bool `json:"enable_fingerprint_unification"` EnableMetadataPassthrough bool `json:"enable_metadata_passthrough"` + EnableCCHSigning bool `json:"enable_cch_signing"` } type DefaultSubscriptionSetting struct { diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index 98a39c30..51bf43cf 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -758,13 +758,14 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { }) } - if len(funcDecls) == 0 { - if !hasWebSearch { - return nil - } - - // Web Search 工具映射 - return []GeminiToolDeclaration{{ + var declarations []GeminiToolDeclaration + if len(funcDecls) > 0 { + declarations = append(declarations, GeminiToolDeclaration{ + FunctionDeclarations: funcDecls, + }) + } + if hasWebSearch { + declarations = append(declarations, GeminiToolDeclaration{ GoogleSearch: &GeminiGoogleSearch{ EnhancedContent: &GeminiEnhancedContent{ ImageSearch: &GeminiImageSearch{ @@ -772,10 +773,11 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { }, }, }, - }} + }) + } + if len(declarations) == 0 { + return nil } - return []GeminiToolDeclaration{{ - FunctionDeclarations: funcDecls, - }} + return declarations } diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index aaf8d72a..18d32af7 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -300,6 +300,29 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { } } +func TestBuildTools_PreservesWebSearchAlongsideFunctions(t *testing.T) { + tools := []ClaudeTool{ + { + Name: "get_weather", + Description: "Get weather information", + InputSchema: map[string]any{"type": "object"}, + }, + { + Type: "web_search_20250305", + Name: "web_search", + }, + } + + result := buildTools(tools) + require.Len(t, result, 2) + require.Len(t, result[0].FunctionDeclarations, 1) + require.Equal(t, "get_weather", result[0].FunctionDeclarations[0].Name) + require.NotNil(t, result[1].GoogleSearch) + require.NotNil(t, result[1].GoogleSearch.EnhancedContent) + require.NotNil(t, result[1].GoogleSearch.EnhancedContent.ImageSearch) + require.Equal(t, 5, result[1].GoogleSearch.EnhancedContent.ImageSearch.MaxResultCount) +} + func TestBuildGenerationConfig_ThinkingDynamicBudget(t *testing.T) { tests := []struct { name string @@ -437,3 +460,36 @@ func TestTransformClaudeToGeminiWithOptions_PreservesBillingHeaderSystemBlock(t }) } } + +func TestTransformClaudeToGeminiWithOptions_PreservesWebSearchAlongsideFunctions(t *testing.T) { + claudeReq := &ClaudeRequest{ + Model: "claude-3-5-sonnet-latest", + Messages: []ClaudeMessage{ + { + Role: "user", + Content: json.RawMessage(`[{"type":"text","text":"hello"}]`), + }, + }, + Tools: []ClaudeTool{ + { + Name: "get_weather", + Description: "Get weather information", + InputSchema: map[string]any{"type": "object"}, + }, + { + Type: "web_search_20250305", + Name: "web_search", + }, + }, + } + + body, err := TransformClaudeToGeminiWithOptions(claudeReq, "project-1", "gemini-2.5-flash", DefaultTransformOptions()) + require.NoError(t, err) + + var req V1InternalRequest + require.NoError(t, json.Unmarshal(body, &req)) + require.Len(t, req.Request.Tools, 2) + require.Len(t, req.Request.Tools[0].FunctionDeclarations, 1) + require.Equal(t, "get_weather", req.Request.Tools[0].FunctionDeclarations[0].Name) + require.NotNil(t, req.Request.Tools[1].GoogleSearch) +} diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index 903d5b31..c140449a 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -181,6 +181,50 @@ func TestChatCompletionsToResponses_ImageURL(t *testing.T) { assert.Equal(t, "data:image/png;base64,abc123", parts[1].ImageURL) } +func TestChatCompletionsToResponses_EmptyBase64ImageURLSkipped(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe this", parts[0].Text) +} + +func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testing.T) { + content := `[{"type":"text","text":"Describe this"},{"type":"image_url","image_url":{"url":"data:image/png;base64, "}}]` + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(content)}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[0].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "input_text", parts[0].Type) + assert.Equal(t, "Describe this", parts[0].Text) +} + func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) { req := &ChatCompletionsRequest{ Model: "gpt-4o", diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index 6cdd012a..dc157a6d 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -339,7 +339,7 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont }) } case "image_url": - if p.ImageURL != nil && p.ImageURL.URL != "" { + if p.ImageURL != nil && p.ImageURL.URL != "" && !isEmptyBase64DataURI(p.ImageURL.URL) { responseParts = append(responseParts, ResponsesContentPart{ Type: "input_image", ImageURL: p.ImageURL.URL, @@ -350,6 +350,22 @@ func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesCont return responseParts } +func isEmptyBase64DataURI(raw string) bool { + if !strings.HasPrefix(raw, "data:") { + return false + } + rest := strings.TrimPrefix(raw, "data:") + semicolonIdx := strings.Index(rest, ";") + if semicolonIdx < 0 { + return false + } + rest = rest[semicolonIdx+1:] + if !strings.HasPrefix(rest, "base64,") { + return false + } + return strings.TrimSpace(strings.TrimPrefix(rest, "base64,")) == "" +} + func flattenChatContentParts(parts []ChatContentPart) string { var textParts []string for _, p := range parts { diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d412ea34..24f60f27 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -536,6 +536,7 @@ func TestAPIContracts(t *testing.T) { "max_claude_code_version": "", "allow_ungrouped_key_scheduling": false, "backend_mode_enabled": false, + "enable_cch_signing": false, "enable_fingerprint_unification": true, "enable_metadata_passthrough": false, "custom_menu_items": [], diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 7b96084d..9667cb98 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform } } +// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。 +// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。 +func (s *ChannelService) storeErrorCache() { + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) +} + // buildCache 从数据库构建渠道缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。 func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { - // 断开请求取消链,避免客户端断连导致空值被长期缓存 dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) defer cancel() - channels, err := s.repo.ListAll(dbCtx) + channels, groupPlatforms, err := s.fetchChannelData(dbCtx) if err != nil { - // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 - slog.Warn("failed to build channel cache", "error", err) - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL - s.cache.Store(errorCache) - return nil, fmt.Errorf("list all channels: %w", err) + return nil, err + } + + cache := populateChannelCache(channels, groupPlatforms) + s.cache.Store(cache) + return cache, nil +} + +// fetchChannelData 从数据库加载渠道列表和分组平台映射。 +func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) { + channels, err := s.repo.ListAll(ctx) + if err != nil { + slog.Warn("failed to build channel cache", "error", err) + s.storeErrorCache() + return nil, nil, fmt.Errorf("list all channels: %w", err) } - // 收集所有 groupID,批量查询 platform var allGroupIDs []int64 for i := range channels { allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) } + groupPlatforms := make(map[int64]string) if len(allGroupIDs) > 0 { - groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) + groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs) if err != nil { slog.Warn("failed to load group platforms for channel cache", "error", err) - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) - s.cache.Store(errorCache) - return nil, fmt.Errorf("get group platforms: %w", err) + s.storeErrorCache() + return nil, nil, fmt.Errorf("get group platforms: %w", err) } } + return channels, groupPlatforms, nil +} +// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。 +func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache { cache := newEmptyChannelCache() cache.groupPlatform = groupPlatforms cache.byID = make(map[int64]*Channel, len(channels)) @@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch - for _, gid := range ch.GroupIDs { cache.channelByGroupID[gid] = ch platform := groupPlatforms[gid] @@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) expandMappingToCache(cache, ch, gid, platform) } } - - // 通配符条目保持配置顺序(最先匹配到优先) - - s.cache.Store(cache) - return cache, nil + return cache } // invalidateCache 使缓存失效,让下次读取时自然重建 @@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 // 返回 true 表示模型被限制(不在允许列表中)。 // 如果渠道未启用模型限制或分组无渠道关联,返回 false。 func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - lk, _ := s.lookupGroupChannel(ctx, groupID) + lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err) + } if lk == nil { return false } @@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { return newBody } +// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 +// Create 和 Update 共用此函数,避免重复。 +func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { + if err := validateNoConflictingModels(pricing); err != nil { + return err + } + if err := validatePricingIntervals(pricing); err != nil { + return err + } + if err := validateNoConflictingMappings(mapping); err != nil { + return err + } + return validatePricingBillingMode(pricing) +} + +// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。 +func validatePricingBillingMode(pricing []ChannelModelPricing) error { + for _, p := range pricing { + if err := checkBillingModeRequirements(p); err != nil { + return err + } + if err := checkPricesNotNegative(p); err != nil { + return err + } + if err := checkIntervalsHavePrices(p); err != nil { + return err + } + } + return nil +} + +func checkBillingModeRequirements(p ChannelModelPricing) error { + if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + return infraerrors.BadRequest( + "BILLING_MODE_MISSING_PRICE", + "per-request price or intervals required for per_request/image billing mode", + ) + } + } + return nil +} + +func checkPricesNotNegative(p ChannelModelPricing) error { + checks := []struct { + field string + val *float64 + }{ + {"input_price", p.InputPrice}, + {"output_price", p.OutputPrice}, + {"cache_write_price", p.CacheWritePrice}, + {"cache_read_price", p.CacheReadPrice}, + {"image_output_price", p.ImageOutputPrice}, + {"per_request_price", p.PerRequestPrice}, + } + for _, c := range checks { + if c.val != nil && *c.val < 0 { + return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field)) + } + } + return nil +} + +func checkIntervalsHavePrices(p ChannelModelPricing) error { + for _, iv := range p.Intervals { + if iv.InputPrice == nil && iv.OutputPrice == nil && + iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && + iv.PerRequestPrice == nil { + return infraerrors.BadRequest( + "INTERVAL_MISSING_PRICE", + fmt.Sprintf("interval [%d, %s] has no price fields set for model %v", + iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models), + ) + } + } + return nil +} + +func formatMaxTokens(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + // --- CRUD --- // Create 创建渠道 @@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) return nil, ErrChannelExists } - // 检查分组冲突 - if len(input.GroupIDs) > 0 { - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) - if err != nil { - return nil, fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return nil, ErrGroupAlreadyInChannel - } + if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil { + return nil, err } channel := &Channel{ @@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) channel.BillingModelSource = BillingModelSourceChannelMapped } - if err := validateNoConflictingModels(channel.ModelPricing); err != nil { - return nil, err - } - if err := validatePricingIntervals(channel.ModelPricing); err != nil { - return nil, err - } - if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } @@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan return nil, fmt.Errorf("get channel: %w", err) } - if input.Name != "" && input.Name != channel.Name { - exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) - if err != nil { - return nil, fmt.Errorf("check channel exists: %w", err) - } - if exists { - return nil, ErrChannelExists - } - channel.Name = input.Name - } - - if input.Description != nil { - channel.Description = *input.Description - } - - if input.Status != "" { - channel.Status = input.Status - } - - if input.RestrictModels != nil { - channel.RestrictModels = *input.RestrictModels - } - - // 检查分组冲突 - if input.GroupIDs != nil { - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) - if err != nil { - return nil, fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return nil, ErrGroupAlreadyInChannel - } - channel.GroupIDs = *input.GroupIDs - } - - if input.ModelPricing != nil { - channel.ModelPricing = *input.ModelPricing - } - - if input.ModelMapping != nil { - channel.ModelMapping = input.ModelMapping - } - - if input.BillingModelSource != "" { - channel.BillingModelSource = input.BillingModelSource - } - - if err := validateNoConflictingModels(channel.ModelPricing); err != nil { - return nil, err - } - if err := validatePricingIntervals(channel.ModelPricing); err != nil { - return nil, err - } - if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + if err := s.applyUpdateInput(ctx, channel, input); err != nil { return nil, err } - // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 - var oldGroupIDs []int64 - if s.authCacheInvalidator != nil { - var err2 error - oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id) - if err2 != nil { - slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) - } + if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { + return nil, err } + oldGroupIDs := s.getOldGroupIDs(ctx, id) + if err := s.repo.Update(ctx, channel); err != nil { return nil, fmt.Errorf("update channel: %w", err) } s.invalidateCache() - - // 失效新旧分组的 auth 缓存 - if s.authCacheInvalidator != nil { - seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) - for _, gid := range oldGroupIDs { - if _, ok := seen[gid]; !ok { - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } - for _, gid := range channel.GroupIDs { - if _, ok := seen[gid]; !ok { - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } - } + s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) return s.repo.GetByID(ctx, id) } +// applyUpdateInput 将更新请求的字段应用到渠道实体上。 +func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error { + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID) + if err != nil { + return fmt.Errorf("check channel exists: %w", err) + } + if exists { + return ErrChannelExists + } + channel.Name = input.Name + } + if input.Description != nil { + channel.Description = *input.Description + } + if input.Status != "" { + channel.Status = input.Status + } + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + if input.GroupIDs != nil { + if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil { + return err + } + channel.GroupIDs = *input.GroupIDs + } + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + return nil +} + +// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。 +// channelID 为当前渠道 ID(Create 时传 0)。 +func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs) + if err != nil { + return fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return ErrGroupAlreadyInChannel + } + return nil +} + +// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。 +func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 { + if s.authCacheInvalidator == nil { + return nil + } + oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID) + if err != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err) + } + return oldGroupIDs +} + +// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。 +func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) { + if s.authCacheInvalidator == nil { + return + } + seen := make(map[int64]struct{}) + for _, ids := range groupIDSets { + for _, gid := range ids { + if _, ok := seen[gid]; ok { + continue + } + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } +} + // Delete 删除渠道 func (s *ChannelService) Delete(ctx context.Context, id int64) error { - // 先获取关联分组用于失效缓存 groupIDs, err := s.repo.GetGroupIDs(ctx, id) if err != nil { slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) @@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { } s.invalidateCache() - - if s.authCacheInvalidator != nil { - for _, gid := range groupIDs { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } + s.invalidateAuthCacheForGroups(ctx, groupIDs) return nil } diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index 3a01fd80..e1345618 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { require.Equal(t, int64(601), result.ID) require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) } + +// --------------------------------------------------------------------------- +// 10. ToUsageFields +// --------------------------------------------------------------------------- + +func TestToUsageFields_NoMapping(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-opus-4", + ChannelID: 1, + Mapped: false, + BillingModelSource: BillingModelSourceRequested, + } + fields := r.ToUsageFields("claude-opus-4", "claude-opus-4") + require.Equal(t, int64(1), fields.ChannelID) + require.Equal(t, "claude-opus-4", fields.OriginalModel) + require.Equal(t, "claude-opus-4", fields.ChannelMappedModel) + require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource) + require.Empty(t, fields.ModelMappingChain) +} + +func TestToUsageFields_WithChannelMapping(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-sonnet-4-20250514", + ChannelID: 2, + Mapped: true, + BillingModelSource: BillingModelSourceChannelMapped, + } + fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514") + require.Equal(t, int64(2), fields.ChannelID) + require.Equal(t, "claude-sonnet-4", fields.OriginalModel) + require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel) + require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain) +} + +func TestToUsageFields_WithUpstreamDifference(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-sonnet-4", + ChannelID: 3, + Mapped: true, + BillingModelSource: BillingModelSourceUpstream, + } + fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514") + require.Equal(t, "my-alias", fields.OriginalModel) + require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel) + require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain) +} + +// --------------------------------------------------------------------------- +// 11. validatePricingBillingMode (moved from handler tests) +// --------------------------------------------------------------------------- + +func TestValidatePricingBillingMode(t *testing.T) { + tests := []struct { + name string + pricing []ChannelModelPricing + wantErr bool + errMsg string + }{ + { + name: "token mode - valid", + pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}}, + }, + { + name: "per_request with price - valid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.5), + }}, + }, + { + name: "per_request with intervals - valid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}}, + }}, + }, + { + name: "per_request no price no intervals - invalid", + pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}}, + wantErr: true, + errMsg: "per-request price or intervals required", + }, + { + name: "image no price no intervals - invalid", + pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}}, + wantErr: true, + errMsg: "per-request price or intervals required", + }, + { + name: "empty list - valid", + pricing: []ChannelModelPricing{}, + }, + { + name: "negative input_price - invalid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(-0.01), + }}, + wantErr: true, + errMsg: "input_price must be >= 0", + }, + { + name: "interval with no price fields - invalid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.5), + Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}}, + }}, + wantErr: true, + errMsg: "has no price fields set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePricingBillingMode(tt.pricing) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// 12. Antigravity wildcard mapping isolation +// --------------------------------------------------------------------------- + +func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: {"claude-*": "claude-override"}, + PlatformGemini: {"gemini-*": "gemini-override"}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic}) + svc := newTestChannelService(repo) + + // antigravity 分组不应看到 anthropic/gemini 的通配符映射 + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-opus-4", result.MappedModel) + + result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro") + require.False(t, result.Mapped) + require.Equal(t, "gemini-2.5-pro", result.MappedModel) + + // anthropic 分组应该能看到 anthropic 的通配符映射 + result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4") + require.True(t, result.Mapped) + require.Equal(t, "claude-override", result.MappedModel) +} + +// --------------------------------------------------------------------------- +// 13. Create/Update with mapping conflict validation +// --------------------------------------------------------------------------- + +func TestCreate_MappingConflict(t *testing.T) { + repo := &mockChannelRepository{} + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "test", + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: { + "claude-*": "target-a", + "claude-opus-*": "target-b", + }, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT") +} + +func TestUpdate_MappingConflict(t *testing.T) { + existingChannel := &Channel{ + ID: 1, + Name: "existing", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existingChannel, nil + }, + } + svc := newTestChannelService(repo) + + conflictMapping := map[string]map[string]string{ + PlatformAnthropic: { + "claude-*": "target-a", + "claude-opus-*": "target-b", + }, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelMapping: conflictMapping, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT") +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 52df52d6..92be3e06 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -218,6 +218,8 @@ const ( SettingKeyEnableFingerprintUnification = "enable_fingerprint_unification" // SettingKeyEnableMetadataPassthrough 是否透传客户端原始 metadata.user_id(默认 false) SettingKeyEnableMetadataPassthrough = "enable_metadata_passthrough" + // SettingKeyEnableCCHSigning 是否对 billing header 中的 cch 进行 xxHash64 签名(默认 false) + SettingKeyEnableCCHSigning = "enable_cch_signing" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index e7661aad..5be1f733 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -761,7 +761,9 @@ func TestGatewayService_AnthropicOAuth_ForwardPreservesBillingHeaderSystemBlock( system := gjson.GetBytes(upstream.lastBody, "system") require.True(t, system.Exists()) - require.Equal(t, claudeCodeSystemPrompt, system.String()) + require.True(t, system.IsArray(), "system should be an array") + require.Equal(t, claudeCodeSystemPrompt, system.Array()[0].Get("text").String()) + require.Equal(t, "ephemeral", system.Array()[0].Get("cache_control.type").String()) // 原始 system prompt 应迁移至 messages 中 messages := gjson.GetBytes(upstream.lastBody, "messages") diff --git a/backend/internal/service/gateway_billing_header.go b/backend/internal/service/gateway_billing_header.go new file mode 100644 index 00000000..91fbfd8f --- /dev/null +++ b/backend/internal/service/gateway_billing_header.go @@ -0,0 +1,73 @@ +package service + +import ( + "fmt" + "regexp" + "strings" + + "github.com/cespare/xxhash/v2" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ccVersionInBillingRe matches the semver part of cc_version (X.Y.Z), preserving +// the trailing message-derived suffix (e.g. ".c02") if present. +var ccVersionInBillingRe = regexp.MustCompile(`cc_version=\d+\.\d+\.\d+`) + +// cchPlaceholderRe matches the cch=00000 placeholder in billing header text, +// scoped to x-anthropic-billing-header to avoid touching user content. +var cchPlaceholderRe = regexp.MustCompile(`(x-anthropic-billing-header:[^"]*?\bcch=)(00000)(;)`) + +const cchSeed uint64 = 0x6E52736AC806831E + +// syncBillingHeaderVersion rewrites cc_version in x-anthropic-billing-header +// system text blocks to match the version extracted from userAgent. +// Only touches system array blocks whose text starts with "x-anthropic-billing-header". +func syncBillingHeaderVersion(body []byte, userAgent string) []byte { + version := ExtractCLIVersion(userAgent) + if version == "" { + return body + } + + systemResult := gjson.GetBytes(body, "system") + if !systemResult.Exists() || !systemResult.IsArray() { + return body + } + + replacement := "cc_version=" + version + idx := 0 + systemResult.ForEach(func(_, item gjson.Result) bool { + text := item.Get("text") + if text.Exists() && text.Type == gjson.String && + strings.HasPrefix(text.String(), "x-anthropic-billing-header") { + newText := ccVersionInBillingRe.ReplaceAllString(text.String(), replacement) + if newText != text.String() { + if updated, err := sjson.SetBytes(body, fmt.Sprintf("system.%d.text", idx), newText); err == nil { + body = updated + } + } + } + idx++ + return true + }) + + return body +} + +// signBillingHeaderCCH computes the xxHash64-based CCH signature for the request +// body and replaces the cch=00000 placeholder with the computed 5-hex-char hash. +// The body must contain the placeholder when this function is called. +func signBillingHeaderCCH(body []byte) []byte { + if !cchPlaceholderRe.Match(body) { + return body + } + cch := fmt.Sprintf("%05x", xxHash64Seeded(body, cchSeed)&0xFFFFF) + return cchPlaceholderRe.ReplaceAll(body, []byte("${1}"+cch+"${3}")) +} + +// xxHash64Seeded computes xxHash64 of data with a custom seed. +func xxHash64Seeded(data []byte, seed uint64) uint64 { + d := xxhash.NewWithSeed(seed) + _, _ = d.Write(data) + return d.Sum64() +} diff --git a/backend/internal/service/gateway_billing_header_test.go b/backend/internal/service/gateway_billing_header_test.go new file mode 100644 index 00000000..ffc4091c --- /dev/null +++ b/backend/internal/service/gateway_billing_header_test.go @@ -0,0 +1,165 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestSyncBillingHeaderVersion(t *testing.T) { + tests := []struct { + name string + body string + userAgent string + wantSub string // substring expected in result + unchanged bool // expect body to remain the same + }{ + { + name: "replaces cc_version preserving message-derived suffix", + body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81.df2; cc_entrypoint=cli; cch=00000;"},{"type":"text","text":"You are Claude Code.","cache_control":{"type":"ephemeral"}}],"messages":[]}`, + userAgent: "claude-cli/2.1.22 (external, cli)", + wantSub: "cc_version=2.1.22.df2", + }, + { + name: "no billing header in system", + body: `{"system":[{"type":"text","text":"You are Claude Code."}],"messages":[]}`, + userAgent: "claude-cli/2.1.22", + unchanged: true, + }, + { + name: "no system field", + body: `{"messages":[]}`, + userAgent: "claude-cli/2.1.22", + unchanged: true, + }, + { + name: "user-agent without version", + body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`, + userAgent: "Mozilla/5.0", + unchanged: true, + }, + { + name: "empty user-agent", + body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.81; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`, + userAgent: "", + unchanged: true, + }, + { + name: "version already matches", + body: `{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.22; cc_entrypoint=cli; cch=00000;"}],"messages":[]}`, + userAgent: "claude-cli/2.1.22", + unchanged: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := syncBillingHeaderVersion([]byte(tt.body), tt.userAgent) + if tt.unchanged { + assert.Equal(t, tt.body, string(result), "body should remain unchanged") + } else { + assert.Contains(t, string(result), tt.wantSub) + // Ensure old semver is gone + assert.NotContains(t, string(result), "cc_version=2.1.81") + } + }) + } +} + +func TestSignBillingHeaderCCH(t *testing.T) { + t.Run("replaces placeholder with hash", func(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63.a43; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + result := signBillingHeaderCCH(body) + + // Should not have the placeholder anymore + assert.NotContains(t, string(result), "cch=00000") + + // Should have a 5 hex-char cch value + billingText := gjson.GetBytes(result, "system.0.text").String() + require.Contains(t, billingText, "cch=") + assert.Regexp(t, `cch=[0-9a-f]{5};`, billingText) + }) + + t.Run("no placeholder - body unchanged", func(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=abcde;"}],"messages":[]}`) + result := signBillingHeaderCCH(body) + assert.Equal(t, string(body), string(result)) + }) + + t.Run("no billing header - body unchanged", func(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"You are Claude Code."}],"messages":[]}`) + result := signBillingHeaderCCH(body) + assert.Equal(t, string(body), string(result)) + }) + + t.Run("cch=00000 in user content is not touched", func(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"keep literal cch=00000 in this message"}]}]}`) + result := signBillingHeaderCCH(body) + + // Billing header should be signed + billingText := gjson.GetBytes(result, "system.0.text").String() + assert.NotContains(t, billingText, "cch=00000") + + // User message should keep its literal cch=00000 + userText := gjson.GetBytes(result, "messages.0.content.0.text").String() + assert.Contains(t, userText, "cch=00000") + }) + + t.Run("signing is deterministic", func(t *testing.T) { + body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":"hi"}]}`) + r1 := signBillingHeaderCCH(body) + body2 := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":"hi"}]}`) + r2 := signBillingHeaderCCH(body2) + assert.Equal(t, string(r1), string(r2)) + }) + + t.Run("matches reference algorithm", func(t *testing.T) { + // Verify: signBillingHeaderCCH(body) produces cch = xxHash64(body_with_placeholder, seed) & 0xFFFFF + body := []byte(`{"system":[{"type":"text","text":"x-anthropic-billing-header: cc_version=2.1.63.a43; cc_entrypoint=cli; cch=00000;"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + expectedCCH := fmt.Sprintf("%05x", xxHash64Seeded(body, cchSeed)&0xFFFFF) + + result := signBillingHeaderCCH(body) + billingText := gjson.GetBytes(result, "system.0.text").String() + assert.Contains(t, billingText, "cch="+expectedCCH+";") + }) +} + +func TestXXHash64Seeded(t *testing.T) { + t.Run("matches cespare/xxhash for seed 0", func(t *testing.T) { + inputs := []string{"", "a", "hello world", "The quick brown fox jumps over the lazy dog"} + for _, s := range inputs { + data := []byte(s) + expected := xxhash.Sum64(data) + got := xxHash64Seeded(data, 0) + assert.Equal(t, expected, got, "mismatch for input %q", s) + } + }) + + t.Run("large input matches cespare", func(t *testing.T) { + data := make([]byte, 256) + for i := range data { + data[i] = byte(i) + } + expected := xxhash.Sum64(data) + got := xxHash64Seeded(data, 0) + assert.Equal(t, expected, got) + }) + + t.Run("deterministic with custom seed", func(t *testing.T) { + data := []byte("hello world") + h1 := xxHash64Seeded(data, cchSeed) + h2 := xxHash64Seeded(data, cchSeed) + assert.Equal(t, h1, h2) + }) + + t.Run("different seeds produce different results", func(t *testing.T) { + data := []byte("test data for hashing") + h1 := xxHash64Seeded(data, 0) + h2 := xxHash64Seeded(data, cchSeed) + assert.NotEqual(t, h1, h2) + }) +} diff --git a/backend/internal/service/gateway_prompt_test.go b/backend/internal/service/gateway_prompt_test.go index d0f5a8c0..e27e18aa 100644 --- a/backend/internal/service/gateway_prompt_test.go +++ b/backend/internal/service/gateway_prompt_test.go @@ -284,7 +284,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { name string body string system any - wantSystemStr string // system 应为纯字符串 + wantSystemText string // system array 第一个 block 的 text wantMessagesLen int // messages 数组长度 wantFirstMsgRole string // 第一条消息的 role wantFirstMsgText string // 第一条消息的 content[0].text @@ -294,21 +294,21 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { name: "nil system - no messages injected", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: nil, - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 1, // 原始 1 条消息,不注入 }, { name: "empty string system - no messages injected", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: "", - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 1, }, { name: "custom string system - migrated to messages", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: "You are a personal assistant running inside OpenClaw.", - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 3, // instruction + ack + original wantFirstMsgRole: "user", wantFirstMsgText: "[System Instructions]\nYou are a personal assistant running inside OpenClaw.", @@ -318,7 +318,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { name: "system equals Claude Code prompt - no messages injected", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: claudeCodeSystemPrompt, - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 1, }, { @@ -328,7 +328,7 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { map[string]any{"type": "text", "text": "First instruction"}, map[string]any{"type": "text", "text": "Second instruction"}, }, - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 3, wantFirstMsgRole: "user", wantFirstMsgText: "[System Instructions]\nFirst instruction\n\nSecond instruction", @@ -338,14 +338,14 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { name: "empty array system - no messages injected", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: []any{}, - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 1, }, { name: "json.RawMessage string system", body: `{"model":"claude-3","system":"Custom prompt","messages":[{"role":"user","content":"hello"}]}`, system: json.RawMessage(`"Custom prompt"`), - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 3, wantFirstMsgRole: "user", wantFirstMsgText: "[System Instructions]\nCustom prompt", @@ -355,14 +355,14 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { name: "json.RawMessage nil system", body: `{"model":"claude-3","messages":[{"role":"user","content":"hello"}]}`, system: json.RawMessage(nil), - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 1, }, { name: "multiple original messages preserved", body: `{"model":"claude-3","messages":[{"role":"user","content":"msg1"},{"role":"assistant","content":"resp1"},{"role":"user","content":"msg2"}]}`, system: "Be helpful", - wantSystemStr: claudeCodeSystemPrompt, + wantSystemText: claudeCodeSystemPrompt, wantMessagesLen: 5, // 2 injected + 3 original wantFirstMsgRole: "user", wantFirstMsgText: "[System Instructions]\nBe helpful", @@ -378,10 +378,17 @@ func TestRewriteSystemForNonClaudeCode(t *testing.T) { err := json.Unmarshal(result, &parsed) require.NoError(t, err) - // system 应为纯字符串 - systemVal, ok := parsed["system"].(string) - require.True(t, ok, "system should be a string, got %T", parsed["system"]) - require.Equal(t, tt.wantSystemStr, systemVal) + // system 应为 array 格式: [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] + systemArr, ok := parsed["system"].([]any) + require.True(t, ok, "system should be an array, got %T", parsed["system"]) + require.Len(t, systemArr, 1, "system array should have exactly 1 block") + systemBlock, ok := systemArr[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "text", systemBlock["type"]) + require.Equal(t, tt.wantSystemText, systemBlock["text"]) + cc, ok := systemBlock["cache_control"].(map[string]any) + require.True(t, ok, "system block should have cache_control") + require.Equal(t, "ephemeral", cc["type"]) // 检查 messages messages, ok := parsed["messages"].([]any) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index ed2be3dc..6d943156 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3736,8 +3736,17 @@ func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { originalSystemText = strings.Join(parts, "\n\n") } - // 2. 将 system 替换为 Claude Code 标准提示词(纯字符串,通过 Anthropic 检测) - out, ok := setJSONValueBytes(body, "system", claudeCodeSystemPrompt) + // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) + // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 + // 使用 string 格式会被 Anthropic 检测为第三方应用。 + claudeCodeSystemBlock := []map[string]any{ + { + "type": "text", + "text": claudeCodeSystemPrompt, + "cache_control": map[string]string{"type": "ephemeral"}, + }, + } + out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) if !ok { logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") return body @@ -3975,17 +3984,22 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if shouldMimicClaudeCode { // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 + systemRewritten := false if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { body = rewriteSystemForNonClaudeCode(body, parsed.System) + systemRewritten = true } - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); + // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 + // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) + _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) if !mimicMPT { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true @@ -5571,9 +5585,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint - enableFP, enableMPT := true, false + enableFP, enableMPT, enableCCH := true, false, false if s.settingService != nil { - enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) @@ -5600,6 +5614,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if fingerprint != nil { + body = syncBillingHeaderVersion(body, fingerprint.UserAgent) + } + // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后) + if enableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -5642,7 +5665,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // Build effective drop set: merge static defaults with dynamic beta policy filter rules policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) effectiveDropSet := mergeDropSets(policyFilterSet) - effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -5653,11 +5675,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Match real Claude CLI traffic (per mitmproxy reports): - // messages requests typically use only oauth + interleaved-thinking. - // Also drop claude-code beta if a downstream client added it. + // Claude Code OAuth credentials are scoped to Claude Code. + // Non-haiku models MUST include claude-code beta for Anthropic to recognize + // this as a legitimate Claude Code request; without it, the request is + // rejected as third-party ("out of extra usage"). + // Haiku models are exempt from third-party detection and don't need it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) + if !strings.Contains(strings.ToLower(modelID), "haiku") { + requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} + } + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") @@ -8501,9 +8528,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT := true, false + ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false if s.settingService != nil { - ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) + ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) } var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { @@ -8521,6 +8548,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if ctFingerprint != nil && ctEnableFP { + body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) + } + if ctEnableCCH { + body = signBillingHeaderCCH(body) + } + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index b35ebce5..32bf21c0 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -612,7 +612,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex fullURL += "?alt=sse" } - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) if err != nil { return nil, "", err } @@ -685,7 +686,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex fullURL += "?alt=sse" } - upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(geminiReq)) + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) if err != nil { return nil, "", err } @@ -3184,12 +3186,17 @@ func convertClaudeToolsToGeminiTools(tools any) []any { return nil } + hasWebSearch := false funcDecls := make([]any, 0, len(arr)) for _, t := range arr { tm, ok := t.(map[string]any) if !ok { continue } + if isClaudeWebSearchToolMap(tm) { + hasWebSearch = true + continue + } var name, desc string var params any @@ -3233,13 +3240,75 @@ func convertClaudeToolsToGeminiTools(tools any) []any { }) } - if len(funcDecls) == 0 { + out := make([]any, 0, 2) + if len(funcDecls) > 0 { + out = append(out, map[string]any{ + "functionDeclarations": funcDecls, + }) + } + if hasWebSearch { + out = append(out, map[string]any{ + "googleSearch": map[string]any{}, + }) + } + if len(out) == 0 { return nil } - return []any{ - map[string]any{ - "functionDeclarations": funcDecls, - }, + return out +} + +func normalizeGeminiRequestForAIStudio(body []byte) []byte { + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + return body + } + + tools, ok := payload["tools"].([]any) + if !ok || len(tools) == 0 { + return body + } + + modified := false + for _, rawTool := range tools { + tool, ok := rawTool.(map[string]any) + if !ok { + continue + } + googleSearch, ok := tool["googleSearch"] + if !ok { + continue + } + if _, exists := tool["google_search"]; exists { + continue + } + tool["google_search"] = googleSearch + delete(tool, "googleSearch") + modified = true + } + + if !modified { + return body + } + + normalized, err := json.Marshal(payload) + if err != nil { + return body + } + return normalized +} + +func isClaudeWebSearchToolMap(tool map[string]any) bool { + toolType, _ := tool["type"].(string) + if strings.HasPrefix(toolType, "web_search") || toolType == "google_search" { + return true + } + + name, _ := tool["name"].(string) + switch strings.TrimSpace(name) { + case "web_search", "google_search", "web_search_20250305": + return true + default: + return false } } diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index f659f0e6..c2adf45d 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -164,6 +164,35 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestConvertClaudeToolsToGeminiTools_PreservesWebSearchAlongsideFunctions(t *testing.T) { + tools := []any{ + map[string]any{ + "name": "get_weather", + "description": "Get weather info", + "input_schema": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "web_search_20250305", + "name": "web_search", + }, + } + + result := convertClaudeToolsToGeminiTools(tools) + require.Len(t, result, 2) + + functionDecl, ok := result[0].(map[string]any) + require.True(t, ok) + funcDecls, ok := functionDecl["functionDeclarations"].([]any) + require.True(t, ok) + require.Len(t, funcDecls, 1) + + searchDecl, ok := result[1].(map[string]any) + require.True(t, ok) + googleSearch, ok := searchDecl["googleSearch"].(map[string]any) + require.True(t, ok) + require.Empty(t, googleSearch) +} + func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { gin.SetMode(gin.TestMode) logSink, restore := captureStructuredLog(t) @@ -232,6 +261,53 @@ func TestGeminiMessagesCompatServiceForward_PreservesRequestedModelAndMappedUpst require.Contains(t, httpStub.lastReq.URL.String(), "/models/claude-sonnet-4-20250514:") } +func TestGeminiMessagesCompatServiceForward_NormalizesWebSearchToolForAIStudio(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"gemini-req-2"}}, + Body: io.NopCloser(strings.NewReader(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}`)), + }, + } + svc := &GeminiMessagesCompatService{httpUpstream: httpStub, cfg: &config.Config{}} + account := &Account{ + ID: 1, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "test-key", + }, + } + body := []byte(`{"model":"claude-sonnet-4","max_tokens":16,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"get_weather","description":"Get weather info","input_schema":{"type":"object"}},{"type":"web_search_20250305","name":"web_search"}]}`) + + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, httpStub.lastReq) + + postedBody, err := io.ReadAll(httpStub.lastReq.Body) + require.NoError(t, err) + + var posted map[string]any + require.NoError(t, json.Unmarshal(postedBody, &posted)) + tools, ok := posted["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 2) + + searchTool, ok := tools[1].(map[string]any) + require.True(t, ok) + _, hasSnake := searchTool["google_search"] + _, hasCamel := searchTool["googleSearch"] + require.True(t, hasSnake) + require.False(t, hasCamel) + _, hasFuncDecl := searchTool["functionDeclarations"] + require.False(t, hasFuncDecl) +} + func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", diff --git a/backend/internal/service/openai_content_session_seed.go b/backend/internal/service/openai_content_session_seed.go new file mode 100644 index 00000000..7c2ba251 --- /dev/null +++ b/backend/internal/service/openai_content_session_seed.go @@ -0,0 +1,107 @@ +package service + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" +) + +// contentSessionSeedPrefix prevents collisions between content-derived seeds +// and explicit session IDs (e.g. "sess-xxx" or "compat_cc_xxx"). +const contentSessionSeedPrefix = "compat_cs_" + +// deriveOpenAIContentSessionSeed builds a stable session seed from an +// OpenAI-format request body. Only fields constant across conversation turns +// are included: model, tools/functions definitions, system/developer prompts, +// instructions (Responses API), and the first user message. +// Supports both Chat Completions (messages) and Responses API (input). +func deriveOpenAIContentSessionSeed(body []byte) string { + if len(body) == 0 { + return "" + } + + var b strings.Builder + + if model := gjson.GetBytes(body, "model").String(); model != "" { + _, _ = b.WriteString("model=") + _, _ = b.WriteString(model) + } + + if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() && tools.Raw != "[]" { + _, _ = b.WriteString("|tools=") + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(tools.Raw))) + } + + if funcs := gjson.GetBytes(body, "functions"); funcs.Exists() && funcs.IsArray() && funcs.Raw != "[]" { + _, _ = b.WriteString("|functions=") + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(funcs.Raw))) + } + + if instr := gjson.GetBytes(body, "instructions").String(); instr != "" { + _, _ = b.WriteString("|instructions=") + _, _ = b.WriteString(instr) + } + + firstUserCaptured := false + + msgs := gjson.GetBytes(body, "messages") + if msgs.Exists() && msgs.IsArray() { + msgs.ForEach(func(_, msg gjson.Result) bool { + role := msg.Get("role").String() + switch role { + case "system", "developer": + _, _ = b.WriteString("|system=") + if c := msg.Get("content"); c.Exists() { + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + case "user": + if !firstUserCaptured { + _, _ = b.WriteString("|first_user=") + if c := msg.Get("content"); c.Exists() { + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + firstUserCaptured = true + } + } + return true + }) + } else if inp := gjson.GetBytes(body, "input"); inp.Exists() { + if inp.Type == gjson.String { + _, _ = b.WriteString("|input=") + _, _ = b.WriteString(inp.String()) + } else if inp.IsArray() { + inp.ForEach(func(_, item gjson.Result) bool { + role := item.Get("role").String() + switch role { + case "system", "developer": + _, _ = b.WriteString("|system=") + if c := item.Get("content"); c.Exists() { + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + case "user": + if !firstUserCaptured { + _, _ = b.WriteString("|first_user=") + if c := item.Get("content"); c.Exists() { + _, _ = b.WriteString(normalizeCompatSeedJSON(json.RawMessage(c.Raw))) + } + firstUserCaptured = true + } + } + if !firstUserCaptured && item.Get("type").String() == "input_text" { + _, _ = b.WriteString("|first_user=") + if text := item.Get("text").String(); text != "" { + _, _ = b.WriteString(text) + } + firstUserCaptured = true + } + return true + }) + } + } + + if b.Len() == 0 { + return "" + } + return contentSessionSeedPrefix + b.String() +} diff --git a/backend/internal/service/openai_content_session_seed_test.go b/backend/internal/service/openai_content_session_seed_test.go new file mode 100644 index 00000000..65a0bf18 --- /dev/null +++ b/backend/internal/service/openai_content_session_seed_test.go @@ -0,0 +1,218 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDeriveOpenAIContentSessionSeed_EmptyInputs(t *testing.T) { + require.Empty(t, deriveOpenAIContentSessionSeed(nil)) + require.Empty(t, deriveOpenAIContentSessionSeed([]byte{})) + require.Empty(t, deriveOpenAIContentSessionSeed([]byte(`{}`))) +} + +func TestDeriveOpenAIContentSessionSeed_ModelOnly(t *testing.T) { + seed := deriveOpenAIContentSessionSeed([]byte(`{"model":"gpt-5.4"}`)) + require.Contains(t, seed, contentSessionSeedPrefix) + require.Contains(t, seed, "model=gpt-5.4") +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_StableAcrossTurns(t *testing.T) { + turn1 := []byte(`{ + "model": "gpt-5.4", + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"} + ] + }`) + turn2 := []byte(`{ + "model": "gpt-5.4", + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + }`) + s1 := deriveOpenAIContentSessionSeed(turn1) + s2 := deriveOpenAIContentSessionSeed(turn2) + require.Equal(t, s1, s2, "seed should be stable across later turns") + require.NotEmpty(t, s1) +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentFirstUserDiffers(t *testing.T) { + req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Question A"}]}`) + req2 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Question B"}]}`) + s1 := deriveOpenAIContentSessionSeed(req1) + s2 := deriveOpenAIContentSessionSeed(req2) + require.NotEqual(t, s1, s2) +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentSystemDiffers(t *testing.T) { + req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"A"},{"role":"user","content":"Hi"}]}`) + req2 := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"B"},{"role":"user","content":"Hi"}]}`) + s1 := deriveOpenAIContentSessionSeed(req1) + s2 := deriveOpenAIContentSessionSeed(req2) + require.NotEqual(t, s1, s2) +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DifferentModelDiffers(t *testing.T) { + req1 := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hi"}]}`) + req2 := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hi"}]}`) + s1 := deriveOpenAIContentSessionSeed(req1) + s2 := deriveOpenAIContentSessionSeed(req2) + require.NotEqual(t, s1, s2) +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_WithTools(t *testing.T) { + withTools := []byte(`{ + "model": "gpt-5.4", + "tools": [{"type":"function","function":{"name":"get_weather"}}], + "messages": [{"role": "user", "content": "Hello"}] + }`) + withoutTools := []byte(`{ + "model": "gpt-5.4", + "messages": [{"role": "user", "content": "Hello"}] + }`) + s1 := deriveOpenAIContentSessionSeed(withTools) + s2 := deriveOpenAIContentSessionSeed(withoutTools) + require.NotEqual(t, s1, s2, "tools should affect the seed") + require.Contains(t, s1, "|tools=") +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_WithFunctions(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "functions": [{"name":"get_weather","parameters":{}}], + "messages": [{"role": "user", "content": "Hello"}] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|functions=") +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_DeveloperRole(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "messages": [ + {"role": "developer", "content": "You are helpful."}, + {"role": "user", "content": "Hello"} + ] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|system=") + require.Contains(t, seed, "|first_user=") +} + +func TestDeriveOpenAIContentSessionSeed_ChatCompletions_StructuredContent(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "messages": [ + {"role": "user", "content": [{"type":"text","text":"Hello"}]} + ] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.NotEmpty(t, seed) + require.Contains(t, seed, "|first_user=") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputString(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","input":"Hello, how are you?"}`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|input=Hello, how are you?") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputArray(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "input": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"} + ] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|system=") + require.Contains(t, seed, "|first_user=") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_WithInstructions(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "instructions": "You are a coding assistant.", + "input": "Write a hello world" + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|instructions=You are a coding assistant.") + require.Contains(t, seed, "|input=Write a hello world") +} + +func TestDeriveOpenAIContentSessionSeed_Deterministic(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"} + ] + }`) + s1 := deriveOpenAIContentSessionSeed(body) + s2 := deriveOpenAIContentSessionSeed(body) + require.Equal(t, s1, s2, "seed must be deterministic") +} + +func TestDeriveOpenAIContentSessionSeed_PrefixPresent(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hi"}]}`) + seed := deriveOpenAIContentSessionSeed(body) + require.True(t, len(seed) > len(contentSessionSeedPrefix)) + require.Equal(t, contentSessionSeedPrefix, seed[:len(contentSessionSeedPrefix)]) +} + +func TestDeriveOpenAIContentSessionSeed_EmptyToolsIgnored(t *testing.T) { + body := []byte(`{"model":"gpt-5.4","tools":[],"messages":[{"role":"user","content":"Hi"}]}`) + seed := deriveOpenAIContentSessionSeed(body) + require.NotContains(t, seed, "|tools=") +} + +func TestDeriveOpenAIContentSessionSeed_MessagesPreferredOverInput(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "messages": [{"role": "user", "content": "from messages"}], + "input": "from input" + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|first_user=") + require.NotContains(t, seed, "|input=") +} + +func TestDeriveOpenAIContentSessionSeed_JSONCanonicalisation(t *testing.T) { + compact := []byte(`{"model":"gpt-5.4","tools":[{"type":"function","function":{"name":"get_weather","description":"Get weather"}}],"messages":[{"role":"user","content":"Hi"}]}`) + spaced := []byte(`{ + "model": "gpt-5.4", + "tools": [ + { "type" : "function", "function": { "description": "Get weather", "name": "get_weather" } } + ], + "messages": [ { "role": "user", "content": "Hi" } ] + }`) + s1 := deriveOpenAIContentSessionSeed(compact) + s2 := deriveOpenAIContentSessionSeed(spaced) + require.Equal(t, s1, s2, "different formatting of identical JSON should produce the same seed") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_InputTextTypedItem(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "input": [{"type": "input_text", "text": "Hello world"}] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|first_user=") + require.Contains(t, seed, "Hello world") +} + +func TestDeriveOpenAIContentSessionSeed_ResponsesAPI_TypedMessageItem(t *testing.T) { + body := []byte(`{ + "model": "gpt-5.4", + "input": [{"type": "message", "role": "user", "content": "Hello from typed message"}] + }`) + seed := deriveOpenAIContentSessionSeed(body) + require.Contains(t, seed, "|first_user=") + require.Contains(t, seed, "Hello from typed message") +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 65e70408..2623d773 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1121,6 +1121,7 @@ func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) str // 1. Header: session_id // 2. Header: conversation_id // 3. Body: prompt_cache_key (opencode) +// 4. Body: content-based fallback (model + system + tools + first user message) func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { if c == nil { return "" @@ -1133,6 +1134,9 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) if sessionID == "" && len(body) > 0 { sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) } + if sessionID == "" && len(body) > 0 { + sessionID = deriveOpenAIContentSessionSeed(body) + } if sessionID == "" { return "" } @@ -2048,6 +2052,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) { + bodyModified = true + disablePatch() + } + // Re-serialize body only if modified if bodyModified { serializedByPatch := false @@ -2475,6 +2484,14 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( reqStream = gjson.GetBytes(body, "stream").Bool() } + sanitizedBody, sanitized, err := sanitizeEmptyBase64InputImagesInOpenAIBody(body) + if err != nil { + return nil, err + } + if sanitized { + body = sanitizedBody + } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, @@ -3007,6 +3024,14 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( return nil, err } + // Detect SSE responses from upstream and convert to JSON. + // Some upstreams (e.g. other sub2api instances) may return SSE even when + // stream=false was requested. Without this conversion the client would + // receive raw SSE text or a terminal event with empty output. + if isEventStreamResponse(resp.Header) { + return s.handlePassthroughSSEToJSON(resp, c, body) + } + usage := &OpenAIUsage{} usageParsed := false if len(body) > 0 { @@ -3030,6 +3055,56 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( return usage, nil } +// handlePassthroughSSEToJSON converts an SSE response body into a JSON +// response for the passthrough path. It mirrors handleSSEToJSON but skips +// model replacement (passthrough does not remap models). +func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte) (*OpenAIUsage, error) { + bodyText := string(body) + finalResponse, ok := extractCodexFinalResponse(bodyText) + + usage := &OpenAIUsage{} + if ok { + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage + } + // When the terminal event has an empty output array, reconstruct + // output from accumulated delta events so the client gets full content. + if len(gjson.GetBytes(finalResponse, "output").Array()) == 0 { + if outputJSON, reconstructed := reconstructResponseOutputFromSSE(bodyText); reconstructed { + if patched, err := sjson.SetRawBytes(finalResponse, "output", outputJSON); err == nil { + finalResponse = patched + } + } + } + body = finalResponse + // Correct tool calls in final response + body = s.correctToolCallsInResponseBody(body) + } else { + terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) + if terminalOK && terminalType == "response.failed" { + msg := extractOpenAISSEErrorMessage(terminalPayload) + if msg == "" { + msg = "Upstream compact response failed" + } + return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) + } + usage = s.parseSSEUsageFromBody(bodyText) + } + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + + contentType := "application/json; charset=utf-8" + if !ok { + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream" + } + } + c.Data(resp.StatusCode, contentType, body) + + return usage, nil +} + func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { if dst == nil || src == nil { return @@ -3858,10 +3933,21 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r return nil, err } + // Detect SSE responses for ALL account types via Content-Type header. + // Some OpenAI-compatible upstreams (including other sub2api instances) + // may return SSE even when stream=false was requested. + if isEventStreamResponse(resp.Header) { + return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) + } + // For OAuth accounts, also fall back to a body-content heuristic because + // the upstream may omit the Content-Type header while still sending SSE. + // This heuristic is NOT applied to API-key accounts to avoid false + // positives on JSON responses that coincidentally contain "data:" or + // "event:" in their text content. if account.Type == AccountTypeOAuth { bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) - if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE { - return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel) + if bodyLooksLikeSSE { + return s.handleSSEToJSON(resp, c, body, originalModel, mappedModel) } } @@ -3895,7 +3981,7 @@ func isEventStreamResponse(header http.Header) bool { return strings.Contains(contentType, "text/event-stream") } -func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { +func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { bodyText := string(body) finalResponse, ok := extractCodexFinalResponse(bodyText) @@ -4954,6 +5040,123 @@ func normalizeOpenAIServiceTier(raw string) *string { } } +func sanitizeEmptyBase64InputImagesInOpenAIBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 || !bytes.Contains(body, []byte(`"image_url"`)) || !bytes.Contains(body, []byte(`base64,`)) { + return body, false, nil + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return body, false, fmt.Errorf("sanitize request body: %w", err) + } + if !sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) { + return body, false, nil + } + normalized, err := json.Marshal(reqBody) + if err != nil { + return body, false, fmt.Errorf("serialize sanitized request body: %w", err) + } + return normalized, true, nil +} + +func sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"] + if !ok { + return false + } + normalizedInput, changed := sanitizeEmptyBase64InputImagesInOpenAIInput(input) + if !changed { + return false + } + reqBody["input"] = normalizedInput + return true +} + +func sanitizeEmptyBase64InputImagesInOpenAIInput(input any) (any, bool) { + items, ok := input.([]any) + if !ok { + return input, false + } + + normalizedItems := make([]any, 0, len(items)) + changed := false + for _, item := range items { + itemMap, ok := item.(map[string]any) + if !ok { + normalizedItems = append(normalizedItems, item) + continue + } + if shouldDropEmptyBase64InputImagePart(itemMap) { + changed = true + continue + } + content, ok := itemMap["content"] + if !ok { + normalizedItems = append(normalizedItems, itemMap) + continue + } + parts, ok := content.([]any) + if !ok { + normalizedItems = append(normalizedItems, itemMap) + continue + } + + normalizedParts := make([]any, 0, len(parts)) + itemChanged := false + for _, part := range parts { + if shouldDropEmptyBase64InputImagePart(part) { + changed = true + itemChanged = true + continue + } + normalizedParts = append(normalizedParts, part) + } + if itemChanged { + if len(normalizedParts) == 0 { + continue + } + itemMap["content"] = normalizedParts + } + normalizedItems = append(normalizedItems, itemMap) + } + if !changed { + return input, false + } + return normalizedItems, true +} + +func shouldDropEmptyBase64InputImagePart(part any) bool { + partMap, ok := part.(map[string]any) + if !ok { + return false + } + typeValue, _ := partMap["type"].(string) + if strings.TrimSpace(typeValue) != "input_image" { + return false + } + imageURL, _ := partMap["image_url"].(string) + return isEmptyBase64DataURI(imageURL) +} + +func isEmptyBase64DataURI(raw string) bool { + if !strings.HasPrefix(raw, "data:") { + return false + } + rest := strings.TrimPrefix(raw, "data:") + semicolonIdx := strings.Index(rest, ";") + if semicolonIdx < 0 { + return false + } + rest = rest[semicolonIdx+1:] + if !strings.HasPrefix(rest, "base64,") { + return false + } + return strings.TrimSpace(strings.TrimPrefix(rest, "base64,")) == "" +} + func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { if c != nil { if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go index f73c06c5..234dee00 100644 --- a/backend/internal/service/openai_gateway_service_hotpath_test.go +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "net/http/httptest" "testing" @@ -139,3 +140,61 @@ func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { require.True(t, ok) require.Equal(t, got, cachedMap) } + +func TestSanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(t *testing.T) { + var reqBody map[string]any + require.NoError(t, json.Unmarshal([]byte(`{ + "model":"gpt-5.4", + "input":[ + {"role":"user","content":[ + {"type":"input_text","text":"Describe this"}, + {"type":"input_image","image_url":"data:image/png;base64, "}, + {"type":"input_image","image_url":"data:image/png;base64,abc123"} + ]}, + {"role":"user","content":[ + {"type":"input_image","image_url":"data:image/png;base64,"} + ]}, + {"type":"input_image","image_url":"data:image/png;base64,"}, + {"type":"input_image","image_url":"data:image/png;base64,top-level-valid"} + ] + }`), &reqBody)) + + require.True(t, sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody)) + + normalized, err := json.Marshal(reqBody) + require.NoError(t, err) + require.JSONEq(t, `{ + "model":"gpt-5.4", + "input":[ + {"role":"user","content":[ + {"type":"input_text","text":"Describe this"}, + {"type":"input_image","image_url":"data:image/png;base64,abc123"} + ]}, + {"type":"input_image","image_url":"data:image/png;base64,top-level-valid"} + ] + }`, string(normalized)) +} + +func TestSanitizeEmptyBase64InputImagesInOpenAIBody(t *testing.T) { + body, changed, err := sanitizeEmptyBase64InputImagesInOpenAIBody([]byte(`{ + "model":"gpt-5.4", + "stream":true, + "input":[ + {"role":"user","content":[ + {"type":"input_text","text":"Describe this"}, + {"type":"input_image","image_url":"data:image/png;base64,"} + ]} + ] + }`)) + require.NoError(t, err) + require.True(t, changed) + require.JSONEq(t, `{ + "model":"gpt-5.4", + "stream":true, + "input":[ + {"role":"user","content":[ + {"type":"input_text","text":"Describe this"} + ]} + ] + }`, string(body)) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 9e2f33f2..cf2d875f 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -237,6 +237,60 @@ func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { require.Equal(t, "", empty) } +func TestOpenAIGatewayService_GenerateSessionHash_ContentFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil) + + svc := &OpenAIGatewayService{} + + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"}]}`) + + hash := svc.GenerateSessionHash(c, body) + require.NotEmpty(t, hash, "content-based fallback should produce a hash") + + hash2 := svc.GenerateSessionHash(c, body) + require.Equal(t, hash, hash2, "same content should produce same hash") + + bodyExtended := []byte(`{"model":"gpt-5.4","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi!"},{"role":"user","content":"How are you?"}]}`) + hashExtended := svc.GenerateSessionHash(c, bodyExtended) + require.Equal(t, hash, hashExtended, "hash should be stable across later turns") + + bodyDifferent := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Different question"}]}`) + hashDifferent := svc.GenerateSessionHash(c, bodyDifferent) + require.NotEqual(t, hash, hashDifferent, "different content should produce different hash") +} + +func TestOpenAIGatewayService_GenerateSessionHash_ExplicitSignalWinsOverContent(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil) + + svc := &OpenAIGatewayService{} + body := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"Hello"}]}`) + + contentHash := svc.GenerateSessionHash(c, body) + require.NotEmpty(t, contentHash) + + c.Request.Header.Set("session_id", "explicit-session") + explicitHash := svc.GenerateSessionHash(c, body) + require.NotEmpty(t, explicitHash) + require.NotEqual(t, contentHash, explicitHash, "explicit session_id should override content fallback") +} + +func TestOpenAIGatewayService_GenerateSessionHash_EmptyBodyStillEmpty(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/chat/completions", nil) + + svc := &OpenAIGatewayService{} + require.Empty(t, svc.GenerateSessionHash(c, []byte(`{}`))) + require.Empty(t, svc.GenerateSessionHash(c, nil)) +} + func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { @@ -1797,7 +1851,7 @@ func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { require.Contains(t, string(finalResp), `"input_tokens":11`) } -func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { +func TestHandleSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) @@ -1814,7 +1868,7 @@ func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { `data: [DONE]`, }, "\n")) - usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") require.NoError(t, err) require.NotNil(t, usage) require.Equal(t, 7, usage.InputTokens) @@ -1826,7 +1880,7 @@ func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { require.NotContains(t, rec.Body.String(), "data:") } -func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { +func TestHandleSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) @@ -1842,7 +1896,7 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { `data: [DONE]`, }, "\n")) - usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") require.NoError(t, err) require.NotNil(t, usage) require.Equal(t, 0, usage.InputTokens) @@ -1850,7 +1904,7 @@ func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) } -func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { +func TestHandleSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) @@ -1866,7 +1920,7 @@ func TestHandleOAuthSSEToJSON_ResponseFailedReturnsProtocolError(t *testing.T) { `data: [DONE]`, }, "\n")) - usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + usage, err := svc.handleSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") require.Nil(t, usage) require.Error(t, err) require.Equal(t, http.StatusBadGateway, rec.Code) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index b7145121..7d0ef5bd 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -81,6 +81,7 @@ const backendModeDBTimeout = 5 * time.Second type cachedGatewayForwardingSettings struct { fingerprintUnification bool metadataPassthrough bool + cchSigning bool expiresAt int64 // unix nano } @@ -514,6 +515,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet // Gateway forwarding behavior updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) + updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) err = s.settingRepo.SetMultiple(ctx, updates) if err == nil { @@ -533,6 +535,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ fingerprintUnification: settings.EnableFingerprintUnification, metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) if s.onUpdate != nil { @@ -639,20 +642,20 @@ func (s *SettingService) IsBackendModeEnabled(ctx context.Context) bool { // GetGatewayForwardingSettings returns cached gateway forwarding settings. // Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. -// Returns (fingerprintUnification, metadataPassthrough). -func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough bool) { +// Returns (fingerprintUnification, metadataPassthrough, cchSigning). +func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fingerprintUnification, metadataPassthrough, cchSigning bool) { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return cached.fingerprintUnification, cached.metadataPassthrough + return cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning } } type gwfResult struct { - fp, mp bool + fp, mp, cch bool } val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { if time.Now().UnixNano() < cached.expiresAt { - return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough}, nil + return gwfResult{cached.fingerprintUnification, cached.metadataPassthrough, cached.cchSigning}, nil } } dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) @@ -660,32 +663,36 @@ func (s *SettingService) GetGatewayForwardingSettings(ctx context.Context) (fing values, err := s.settingRepo.GetMultiple(dbCtx, []string{ SettingKeyEnableFingerprintUnification, SettingKeyEnableMetadataPassthrough, + SettingKeyEnableCCHSigning, }) if err != nil { slog.Warn("failed to get gateway forwarding settings", "error", err) gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ fingerprintUnification: true, metadataPassthrough: false, + cchSigning: false, expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), }) - return gwfResult{true, false}, nil + return gwfResult{true, false, false}, nil } fp := true if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { fp = v == "true" } mp := values[SettingKeyEnableMetadataPassthrough] == "true" + cch := values[SettingKeyEnableCCHSigning] == "true" gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ fingerprintUnification: fp, metadataPassthrough: mp, + cchSigning: cch, expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), }) - return gwfResult{fp, mp}, nil + return gwfResult{fp, mp, cch}, nil }) if r, ok := val.(gwfResult); ok { - return r.fp, r.mp + return r.fp, r.mp, r.cch } - return true, false // fail-open defaults + return true, false, false // fail-open defaults } // IsEmailVerifyEnabled 检查是否开启邮件验证 @@ -983,13 +990,14 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin // 分组隔离 result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true" - // Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false) + // Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false, cch_signing=false) if v, ok := settings[SettingKeyEnableFingerprintUnification]; ok && v != "" { result.EnableFingerprintUnification = v == "true" } else { result.EnableFingerprintUnification = true // default: enabled (current behavior) } result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" + result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" return result } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 473d7297..fedb3f2f 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -78,6 +78,7 @@ type SystemSettings struct { // Gateway forwarding behavior EnableFingerprintUnification bool // 是否统一 OAuth 账号的指纹头(默认 true) EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) + EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) } type DefaultSubscriptionSetting struct { diff --git a/deploy/Dockerfile b/deploy/Dockerfile index 7caa5ca6..b0b6036c 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.26.1-alpine +ARG GOLANG_IMAGE=golang:1.26.2-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 013f2dfb..b7ee6be5 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -89,6 +89,7 @@ export interface SystemSettings { // Gateway forwarding behavior enable_fingerprint_unification: boolean enable_metadata_passthrough: boolean + enable_cch_signing: boolean } export interface UpdateSettingsRequest { @@ -146,6 +147,7 @@ export interface UpdateSettingsRequest { allow_ungrouped_key_scheduling?: boolean enable_fingerprint_unification?: boolean enable_metadata_passthrough?: boolean + enable_cch_signing?: boolean } /** diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index fc9297fd..d3b16d4a 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -4268,6 +4268,8 @@ export default { fingerprintUnificationHint: 'Unify X-Stainless-* headers across users sharing the same OAuth account. Disabling passes through each client\'s original headers.', metadataPassthrough: 'Metadata Passthrough', metadataPassthroughHint: 'Pass through client\'s original metadata.user_id without rewriting. May improve upstream cache hit rates.', + cchSigning: 'CCH Signing', + cchSigningHint: 'Sign the billing header in forwarded requests with CCH hash. When disabled, the placeholder is preserved.', }, site: { title: 'Site Settings', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 57bfefdc..fcaaf5ab 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -4431,6 +4431,8 @@ export default { fingerprintUnificationHint: '统一共享同一 OAuth 账号的用户的 X-Stainless-* 请求头。关闭后透传客户端原始请求头。', metadataPassthrough: 'Metadata 透传', metadataPassthroughHint: '透传客户端原始 metadata.user_id,不进行重写。可能提高上游缓存命中率。', + cchSigning: 'CCH 签名', + cchSigningHint: '对转发请求的 billing header 进行 CCH 哈希签名。关闭时保留原始占位符。', }, site: { title: '站点设置', diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 9ae40aeb..f43140ab 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -1376,6 +1376,19 @@ + + +
+
+ +

+ {{ t('admin.settings.gatewayForwarding.cchSigningHint') }} +

+
+ +
@@ -2248,7 +2261,8 @@ const form = reactive({ allow_ungrouped_key_scheduling: false, // Gateway forwarding behavior enable_fingerprint_unification: true, - enable_metadata_passthrough: false + enable_metadata_passthrough: false, + enable_cch_signing: false }) const defaultSubscriptionGroupOptions = computed(() => @@ -2556,7 +2570,8 @@ async function saveSettings() { max_claude_code_version: form.max_claude_code_version, allow_ungrouped_key_scheduling: form.allow_ungrouped_key_scheduling, enable_fingerprint_unification: form.enable_fingerprint_unification, - enable_metadata_passthrough: form.enable_metadata_passthrough + enable_metadata_passthrough: form.enable_metadata_passthrough, + enable_cch_signing: form.enable_cch_signing } const updated = await adminAPI.settings.updateSettings(payload) Object.assign(form, updated)