Merge pull request #2820 from Pluviobyte/fix/antigravity-passthrough-message-start-input-tokens

fix(antigravity): capture message_start input_tokens in streaming passthrough
This commit is contained in:
Wesley Liddick 2026-05-27 20:59:51 +08:00 committed by GitHub
commit 7579058e91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 2 deletions

View File

@ -4463,6 +4463,14 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp
}
// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景)
//
// Anthropic streaming 的 usage 字段分布在两类事件中:
// - message_start:嵌套在 event.message.usage(input_tokens、cache_creation_input_tokens、
// cache_read_input_tokens 等输入侧字段)
// - message_delta:位于顶层 event.usage(流结束时的最终 output_tokens)
//
// 仅读取顶层 event.usage 会漏掉 message_start 的输入侧字段,导致流式透传请求落库的
// usage_logs 记录 input_tokens=0。
func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) {
if !strings.HasPrefix(line, "data: ") {
return
@ -4472,8 +4480,15 @@ func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUs
if json.Unmarshal([]byte(dataStr), &event) != nil {
return
}
u, ok := event["usage"].(map[string]any)
if !ok {
var u map[string]any
if eventType, _ := event["type"].(string); eventType == "message_start" {
if msg, ok := event["message"].(map[string]any); ok {
u, _ = msg["usage"].(map[string]any)
}
} else {
u, _ = event["usage"].(map[string]any)
}
if u == nil {
return
}
if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 {

View File

@ -1301,6 +1301,19 @@ func TestExtractSSEUsage(t *testing.T) {
line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`,
expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3},
},
{
// Anthropic message_start 把 usage 嵌套在 message.usage 下,
// 必须从这里提取输入侧字段(含 cache_read/cache_creation_input_tokens)。
name: "message_start nested usage with input/cache tokens",
line: `data: {"type":"message_start","message":{"id":"msg_01","usage":{"input_tokens":35576,"cache_creation_input_tokens":0,"cache_read_input_tokens":12000,"output_tokens":1}}}`,
expected: ClaudeUsage{InputTokens: 35576, OutputTokens: 1, CacheReadInputTokens: 12000},
},
{
// message_start.message.usage.cache_creation 内的 5m/1h 明细也要解析。
name: "message_start nested usage with cache_creation breakdown",
line: `data: {"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`,
expected: ClaudeUsage{InputTokens: 100, CacheCreation5mTokens: 30, CacheCreation1hTokens: 70},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -1311,6 +1324,29 @@ func TestExtractSSEUsage(t *testing.T) {
}
}
// TestExtractSSEUsage_StreamingSequence reproduces issue #2332: a complete
// Anthropic streaming sequence (message_start followed by message_delta) must
// have the usage fields of both event kinds folded into one accumulated value.
// Per the issue, passthrough accounts otherwise produce "partial" usage_logs
// records with input_tokens=0 and only output_tokens populated.
func TestExtractSSEUsage_StreamingSequence(t *testing.T) {
	// SSE data lines in the order they arrive on the stream.
	events := []string{
		// 1) message_start: usage is nested under message.usage and carries the
		//    input-side fields (input_tokens plus cache_read_input_tokens).
		`data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","content":[],"model":"claude-opus-4-6","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":35576,"cache_creation_input_tokens":0,"cache_read_input_tokens":12000,"output_tokens":1}}}`,
		// 2) message_delta: sent at the end of the stream with a top-level usage
		//    that only has output_tokens (no input_tokens field).
		`data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":816}}`,
	}

	gateway := &AntigravityGatewayService{}

	// A single accumulator is shared across all events, so every call to
	// extractSSEUsage updates the same ClaudeUsage value.
	var total ClaudeUsage
	for _, line := range events {
		gateway.extractSSEUsage(line, &total)
	}

	require.Equal(t, 35576, total.InputTokens, "message_start 的 input_tokens 必须被记录,否则记账会缺失输入侧 token (#2332)")
	require.Equal(t, 12000, total.CacheReadInputTokens, "message_start 的 cache_read_input_tokens 必须被记录")
	require.Equal(t, 816, total.OutputTokens, "message_delta 的最终 output_tokens 必须被记录")
}
// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测
func TestAntigravityClientWriter(t *testing.T) {
t.Run("normal write succeeds", func(t *testing.T) {