Merge pull request #2849 from Pluviobyte/fix/count-tokens-payload-filter
fix(gateway): filter count_tokens generation fields
This commit is contained in:
commit
52292741cb
@ -476,6 +476,66 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_ModelMappingPreservesOtherFie
|
||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int(), "max_tokens 不应被修改")
|
||||
}
|
||||
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokensFiltersGenerationFields(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil)
|
||||
|
||||
body := []byte(`{"model":"claude-sonnet-4-20250514","system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"tool","input_schema":{"type":"object"}}],"temperature":0.7,"top_p":0.9,"top_k":40,"stream":true,"stop_sequences":["END"],"max_tokens":1024,"thinking":{"type":"enabled","budget_tokens":5000}}`)
|
||||
parsed := &ParsedRequest{
|
||||
Body: body,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
}
|
||||
|
||||
upstreamRespBody := `{"input_tokens":42}`
|
||||
upstream := &anthropicHTTPUpstreamRecorder{
|
||||
resp: &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(upstreamRespBody)),
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
cfg: &config.Config{Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}},
|
||||
httpUpstream: upstream,
|
||||
rateLimitService: &RateLimitService{},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Name: "count-token-filter-test",
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "upstream-key",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
},
|
||||
Extra: map[string]any{"anthropic_passthrough": true},
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := svc.ForwardCountTokens(context.Background(), c, account, parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
sentBody := upstream.lastBody
|
||||
require.False(t, gjson.GetBytes(sentBody, "temperature").Exists())
|
||||
require.False(t, gjson.GetBytes(sentBody, "top_p").Exists())
|
||||
require.False(t, gjson.GetBytes(sentBody, "top_k").Exists())
|
||||
require.False(t, gjson.GetBytes(sentBody, "stream").Exists())
|
||||
require.False(t, gjson.GetBytes(sentBody, "stop_sequences").Exists())
|
||||
require.Equal(t, "claude-sonnet-4-20250514", gjson.GetBytes(sentBody, "model").String())
|
||||
require.Equal(t, "sys", gjson.GetBytes(sentBody, "system.0.text").String())
|
||||
require.Equal(t, "hello", gjson.GetBytes(sentBody, "messages.0.content").String())
|
||||
require.Equal(t, "tool", gjson.GetBytes(sentBody, "tools.0.name").String())
|
||||
require.Equal(t, int64(1024), gjson.GetBytes(sentBody, "max_tokens").Int())
|
||||
require.Equal(t, "enabled", gjson.GetBytes(sentBody, "thinking.type").String())
|
||||
}
|
||||
|
||||
// TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping
|
||||
// 确保空模型名不会触发映射逻辑
|
||||
func TestGatewayService_AnthropicAPIKeyPassthrough_EmptyModelSkipsMapping(t *testing.T) {
|
||||
|
||||
@ -9448,6 +9448,7 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough(
|
||||
}
|
||||
targetURL = validatedURL + "/v1/messages/count_tokens?beta=true"
|
||||
}
|
||||
body = sanitizeCountTokensRequestBody(body)
|
||||
|
||||
// 同 buildUpstreamRequestAnthropicAPIKeyPassthrough:能力维度 sanitize。
|
||||
clientBeta := ""
|
||||
@ -9564,6 +9565,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
if ctEnableCCH {
|
||||
body = signBillingHeaderCCH(body)
|
||||
}
|
||||
body = sanitizeCountTokensRequestBody(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
@ -9634,6 +9636,25 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func sanitizeCountTokensRequestBody(body []byte) []byte {
|
||||
out := body
|
||||
for _, path := range []string{
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"stream",
|
||||
"stop_sequences",
|
||||
"stop",
|
||||
} {
|
||||
if gjson.GetBytes(out, path).Exists() {
|
||||
if next, ok := deleteJSONPathBytes(out, path); ok {
|
||||
out = next
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// countTokensError 返回 count_tokens 错误响应
|
||||
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user