diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3d5e151f..4d116313 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -485,6 +485,10 @@ type GatewayConfig struct { // RPMSmoothing: RPM 令牌桶平滑配置 // 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429 RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"` + + // ContextCompression: 主动上下文压缩配置 + // 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息 + ContextCompression ContextCompressionConfig `mapstructure:"context_compression"` } type GatewayAntigravityLSWorkerConfig struct { @@ -556,6 +560,20 @@ func (c *RPMSmoothingConfig) MaxWait() time.Duration { return time.Duration(c.MaxWaitMS) * time.Millisecond } +// ContextCompressionConfig 主动上下文压缩配置 +type ContextCompressionConfig struct { + // MaxTokens: 压缩目标 token 数(chars/4 近似),超出时从最旧消息开始裁剪(默认 190000) + MaxTokens int `mapstructure:"max_tokens"` +} + +// GetMaxTokens returns the configured token budget, defaulting to 190 000. +func (c *ContextCompressionConfig) GetMaxTokens() int { + if c.MaxTokens <= 0 { + return 190_000 + } + return c.MaxTokens +} + // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 type GatewayOpenAIWSConfig struct { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a1449ffd..feb1da37 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -771,6 +771,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool { return false } +// IsContextCompressionEnabled returns true if the account has opted into proactive +// context compression. When enabled, the gateway will trim oldest messages before +// dispatch to keep the estimated token count within the configured budget. +func (a *Account) IsContextCompressionEnabled() bool { + if a.Credentials == nil { + return false + } + enabled, _ := a.Credentials["enable_context_compression"].(bool) + return enabled +} + func (a *Account) IsBedrock() bool { return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock } diff --git a/backend/internal/service/context_compressor.go b/backend/internal/service/context_compressor.go new file mode 100644 index 00000000..a3500cda --- /dev/null +++ b/backend/internal/service/context_compressor.go @@ -0,0 +1,151 @@ +package service + +import ( + "encoding/json" + "math" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation). +// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response. +const defaultContextCompressionMaxTokens = 190_000 + +// approxTokens estimates the token count for a string using the chars/4 heuristic. +func approxTokens(s string) int { + return int(math.Ceil(float64(len(s)) / 4.0)) +} + +// compressMessagesInBody trims the oldest messages from the request body so that the +// estimated token count of the messages array fits within maxTokens. +// Returns the original body unchanged if no compression is needed or if parsing fails. +func compressMessagesInBody(body []byte, maxTokens int) []byte { + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.Exists() || !msgsResult.IsArray() { + return body + } + + // Unmarshal to a typed slice for processing. + var messages []map[string]any + if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil { + return body + } + + compressed, changed := compressMessages(messages, maxTokens) + if !changed { + return body + } + + newMsgs, err := json.Marshal(compressed) + if err != nil { + return body + } + updated, err := sjson.SetRawBytes(body, "messages", newMsgs) + if err != nil { + return body + } + return updated +} + +// compressMessages removes the oldest messages from the front of msgs until the +// estimated total token count is at or below maxTokens. +// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically +// to avoid orphaned tool_result blocks. +// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise. +func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) { + if len(msgs) == 0 { + return msgs, false + } + + // Estimate total tokens. + totalTokens := 0 + for _, m := range msgs { + totalTokens += msgTokens(m) + } + if totalTokens <= maxTokens { + return msgs, false + } + + // Build atomic removal units: tool_use+tool_result consecutive pairs are one unit. + type unit struct { + startIdx int + endIdx int // exclusive + tokens int + } + units := make([]unit, 0, len(msgs)) + i := 0 + for i < len(msgs) { + toks := msgTokens(msgs[i]) + if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) { + toks += msgTokens(msgs[i+1]) + units = append(units, unit{i, i + 2, toks}) + i += 2 + } else { + units = append(units, unit{i, i + 1, toks}) + i++ + } + } + + // Remove units from the front until we are within budget. + // Always keep at least the last unit so we never send an empty messages array. + removeCount := 0 + for removeCount < len(units)-1 && totalTokens > maxTokens { + totalTokens -= units[removeCount].tokens + removeCount++ + } + if removeCount == 0 { + return msgs, false + } + + cutIdx := units[removeCount].startIdx + return msgs[cutIdx:], true +} + +// msgTokens estimates token count for a single message using the chars/4 heuristic. +func msgTokens(msg map[string]any) int { + b, err := json.Marshal(msg) + if err != nil { + return 0 + } + return approxTokens(string(b)) +} + +// isAssistantWithToolUse returns true if msg is an assistant message whose content +// contains at least one block with "type": "tool_use". +func isAssistantWithToolUse(msg map[string]any) bool { + role, _ := msg["role"].(string) + if role != "assistant" { + return false + } + return contentContainsType(msg["content"], "tool_use") +} + +// isUserWithToolResult returns true if msg is a user message whose content +// contains at least one block with "type": "tool_result". +func isUserWithToolResult(msg map[string]any) bool { + role, _ := msg["role"].(string) + if role != "user" { + return false + } + return contentContainsType(msg["content"], "tool_result") +} + +// contentContainsType returns true if content (a []any of blocks) contains a block +// whose "type" field equals blockType. +func contentContainsType(content any, blockType string) bool { + blocks, ok := content.([]any) + if !ok { + return false + } + for _, b := range blocks { + block, ok := b.(map[string]any) + if !ok { + continue + } + if t, _ := block["type"].(string); t == blockType { + return true + } + } + return false +} diff --git a/backend/internal/service/context_compressor_test.go b/backend/internal/service/context_compressor_test.go new file mode 100644 index 00000000..bad3983d --- /dev/null +++ b/backend/internal/service/context_compressor_test.go @@ -0,0 +1,195 @@ +package service + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// helpers + +func makeMsg(role, text string) map[string]any { + return map[string]any{ + "role": role, + "content": text, + } +} + +func makeToolUseMsg(id string) map[string]any { + return map[string]any{ + "role": "assistant", + "content": []any{ + map[string]any{ + "type": "tool_use", + "id": id, + "name": "search", + "input": map[string]any{}, + }, + }, + } +} + +func makeToolResultMsg(toolUseID string) map[string]any { + return map[string]any{ + "role": "user", + "content": []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": toolUseID, + "content": "result text", + }, + }, + } +} + +func toAnySlice(msgs []map[string]any) []any { + out := make([]any, len(msgs)) + for i, m := range msgs { + out[i] = m + } + return out +} + +func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte { + t.Helper() + b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"}) + require.NoError(t, err) + return b +} + +// tests + +func TestApproxTokens(t *testing.T) { + assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token + assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens + assert.Equal(t, 0, approxTokens("")) +} + +func TestCompressMessages_NoCompressionNeeded(t *testing.T) { + msgs := []map[string]any{ + makeMsg("user", "hi"), + makeMsg("assistant", "hello"), + } + result, changed := compressMessages(msgs, 100_000) + assert.False(t, changed) + assert.Len(t, result, 2) +} + +func TestCompressMessages_TrimsOldestMessages(t *testing.T) { + // 10 messages, each large enough to be over a tight budget when combined. + msgs := make([]map[string]any, 10) + for i := range msgs { + role := "user" + if i%2 == 1 { + role = "assistant" + } + msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i)) + } + + // Force compression by using a very small token budget. + result, changed := compressMessages(msgs, 1) + assert.True(t, changed) + // Must keep at least one message (the last). + assert.GreaterOrEqual(t, len(result), 1) + // The remaining messages should be from the tail (newest). + lastOrig := msgs[len(msgs)-1]["content"] + lastResult := result[len(result)-1]["content"] + assert.Equal(t, lastOrig, lastResult) +} + +func TestCompressMessages_PreservesToolUsePairs(t *testing.T) { + // Messages: user → assistant+tool_use → user+tool_result → assistant + msgs := []map[string]any{ + makeMsg("user", "start"), + makeToolUseMsg("tool-1"), + makeToolResultMsg("tool-1"), + makeMsg("assistant", "done"), + } + + // Budget that forces removal of the first non-paired message but keeps the tool pair. + // Estimate total tokens and set budget to force removing only "start" but not the pair. + total := 0 + for _, m := range msgs { + total += msgTokens(m) + } + // Budget: remove "start" but keep tool pair + "done". + startTokens := msgTokens(msgs[0]) + budget := total - startTokens + + result, changed := compressMessages(msgs, budget) + assert.True(t, changed) + + // tool_use and tool_result should both be present or both absent. + hasToolUse := false + hasToolResult := false + for _, m := range result { + if isAssistantWithToolUse(m) { + hasToolUse = true + } + if isUserWithToolResult(m) { + hasToolResult = true + } + } + assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all") +} + +func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) { + // Budget forces removal of the tool pair. + msgs := []map[string]any{ + makeMsg("user", "start"), + makeToolUseMsg("tool-1"), + makeToolResultMsg("tool-1"), + makeMsg("assistant", "final answer after tool use"), + } + + // Budget: only keep the last "assistant" message. + lastTokens := msgTokens(msgs[len(msgs)-1]) + + result, changed := compressMessages(msgs, lastTokens) + assert.True(t, changed) + + // Neither tool_use nor tool_result should remain. + for _, m := range result { + assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair") + assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair") + } +} + +func TestCompressMessagesInBody_NoMessages(t *testing.T) { + body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`) + result := compressMessagesInBody(body, 1) + assert.Equal(t, body, result, "body without messages should be unchanged") +} + +func TestCompressMessagesInBody_UnderBudget(t *testing.T) { + msgs := []map[string]any{makeMsg("user", "hi")} + body := bodyWithMessages(t, msgs) + result := compressMessagesInBody(body, 100_000) + assert.Equal(t, body, result, "body under budget should be unchanged") +} + +func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) { + msgs := make([]map[string]any, 20) + for i := range msgs { + role := "user" + if i%2 == 1 { + role = "assistant" + } + msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i)) + } + body := bodyWithMessages(t, msgs) + + // Force significant compression. + result := compressMessagesInBody(body, 50) + assert.Less(t, len(result), len(body), "compressed body should be smaller") + + // Resulting body should still be valid JSON with a messages array. + var parsed map[string]any + require.NoError(t, json.Unmarshal(result, &parsed)) + resultMsgs, ok := parsed["messages"].([]any) + require.True(t, ok) + assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty") +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index dfe3fe34..23a7ccbc 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4182,6 +4182,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. body = StripEmptyTextBlocks(body) + // 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。 + if account.IsContextCompressionEnabled() { + maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens() + body = compressMessagesInBody(body, maxTok) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 setOpsUpstreamRequestBody(c, body)