fix(gateway): filter count_tokens generation fields

Anthropic count_tokens rejects generation-only fields such as temperature, top_p, top_k, stream, and stop sequences. Passing the original messages payload through unchanged can turn otherwise valid requests into upstream 400 errors.

Sanitize only the count_tokens upstream payload after the gateway's existing request normalization, preserving fields that existing compatibility paths rely on while removing parameters the count_tokens endpoint does not accept.

Fixes #2764

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Pluviobyte 2026-05-28 05:40:50 +00:00
parent 89d96f4b25
commit 27600b1d2c
No known key found for this signature in database
2 changed files with 81 additions and 0 deletions

View File

@ -476,6 +476,66 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFie
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
}
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFiltersGenerationFields(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"tool","input_schema":{"type":"object"}}],"temperature":0.7,"top_p":0.9,"top_k":40,"stream":true,"stop_sequences":["END"],"max_tokens":1024,"thinking":{"type":"enabled","budget_tokens":5000}}`)
parsed := &ParsedRequest{
Body: body,
Model: "claude-sonnet-4-20250514",
}
upstreamRespBody := `{"input_tokens":42}`
upstream := &anthropicHTTPUpstreamRecorder{
resp: &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
},
}
svc := &GatewayService{
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
httpUpstream: upstream,
rateLimitService: &RateLimitService{},
}
account := &Account{
ID: 302,
Name: "count-token-filter-test",
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Concurrency: 1,
Credentials: map[string]any{
"api_key": "upstream-key",
"base_url": "https://api.anthropic.com",
},
Extra: map[string]any{"anthropic_passthrough": true},
Status: StatusActive,
Schedulable: true,
}
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
require.NoError(t, err)
sentBody := upstream.lastBody
require.False(t, gjson.GetBytes(sentBody, "temperature").Exists())
require.False(t, gjson.GetBytes(sentBody, "top_p").Exists())
require.False(t, gjson.GetBytes(sentBody, "top_k").Exists())
require.False(t, gjson.GetBytes(sentBody, "stream").Exists())
require.False(t, gjson.GetBytes(sentBody, "stop_sequences").Exists())
require.Equal(t, "claude-sonnet-4-20250514", gjson.GetBytes(sentBody, "model").String())
require.Equal(t, "sys", gjson.GetBytes(sentBody, "system.0.text").String())
require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String())
require.Equal(t, "tool", gjson.GetBytes(sentBody, "tools.0.name").String())
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int())
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String())
}
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
// 确保空模型名不会触发映射逻辑
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {

View File

@ -9311,6 +9311,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
}
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
}
body = sanitizeCountTokensRequestBody(body)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
if err != nil {
@ -9405,6 +9406,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if ctEnableCCH {
body = signBillingHeaderCCH(body)
}
body = sanitizeCountTokensRequestBody(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
@ -9501,6 +9503,25 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
return req, nil
}
func sanitizeCountTokensRequestBody(body []byte) []byte {
out := body
for _, path := range []string{
"temperature",
"top_p",
"top_k",
"stream",
"stop_sequences",
"stop",
} {
if gjson.GetBytes(out, path).Exists() {
if next, ok := deleteJSONPathBytes(out, path); ok {
out = next
}
}
}
return out
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{