From 27600b1d2c9579f83abbb1ee469bfc45b908f840 Mon Sep 17 00:00:00 2001 From: Pluviobyte Date: Thu, 28 May 2026 05:40:50 +0000 Subject: [PATCH] fix(gateway): filter count_tokens generation fields Anthropic count_tokens rejects generation-only fields such as temperature, top_p, top_k, stream, and stop sequences. Passing the original messages payload through unchanged can turn otherwise valid requests into upstream 400 errors. Sanitize only the count_tokens upstream payload after the gateway's existing request normalization, preserving fields that existing compatibility paths rely on while removing parameters the count_tokens endpoint does not accept. Fixes #2764 Co-authored-by: Cursor --- ...teway_anthropic_apikey_passthrough_test.go | 60 +++++++++++++++++++ backend/internal/service/gateway_service.go | 21 +++++++ 2 files changed, 81 insertions(+) diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index 5cb03f30..9062c517 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -476,6 +476,66 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFie require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改") } +func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFiltersGenerationFields(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + + body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"tool","input_schema":{"type":"object"}}],"temperature":0.7,"top_p":0.9,"top_k":40,"stream":true,"stop_sequences":["END"],"max_tokens":1024,"thinking":{"type":"enabled","budget_tokens":5000}}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-sonnet-4-20250514", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 302, + Name: "count-token-filter-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-key", + "base_url": "https://api.anthropic.com", + }, + Extra: map[string]any{"anthropic_passthrough": true}, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, parsed) + require.NoError(t, err) + + sentBody := upstream.lastBody + require.False(t, gjson.GetBytes(sentBody, "temperature").Exists()) + require.False(t, gjson.GetBytes(sentBody, "top_p").Exists()) + require.False(t, gjson.GetBytes(sentBody, "top_k").Exists()) + require.False(t, gjson.GetBytes(sentBody, "stream").Exists()) + require.False(t, gjson.GetBytes(sentBody, "stop_sequences").Exists()) + require.Equal(t, "claude-sonnet-4-20250514", gjson.GetBytes(sentBody, "model").String()) + require.Equal(t, "sys", gjson.GetBytes(sentBody, "system.0.text").String()) + require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String()) + require.Equal(t, "tool", gjson.GetBytes(sentBody, "tools.0.name").String()) + require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int()) + require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String()) +} + // TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping // 确保空模型名不会触发映射逻辑 func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) { diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4a8175a4..a787e3eb 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -9311,6 +9311,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } + body = sanitizeCountTokensRequestBody(body) req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) if err != nil { @@ -9405,6 +9406,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if ctEnableCCH { body = signBillingHeaderCCH(body) } + body = sanitizeCountTokensRequestBody(body) req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { @@ -9501,6 +9503,25 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con return req, nil } +func sanitizeCountTokensRequestBody(body []byte) []byte { + out := body + for _, path := range []string{ + "temperature", + "top_p", + "top_k", + "stream", + "stop_sequences", + "stop", + } { + if gjson.GetBytes(out, path).Exists() { + if next, ok := deleteJSONPathBytes(out, path); ok { + out = next + } + } + } + return out +} + // countTokensError 返回 count_tokens 错误响应 func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{