feat(context): add proactive context compression for long conversations

- New context_compressor.go: pure functions operating on raw JSON body
  (gjson/sjson pattern). approxTokens uses chars/4 heuristic.
- compressMessages: removes oldest messages from front, treating
  consecutive assistant(tool_use)+user(tool_result) pairs as atomic units
  to prevent orphaned tool_result blocks.
- Hooked into Forward() after StripEmptyTextBlocks, gated on
  account.Credentials[enable_context_compression].
- Config: gateway.context_compression.max_tokens (default 190000).
- 8 unit tests covering: approx tokens, no-op when under budget,
  oldest-message trimming, tool pair preservation, atomic pair removal,
  body passthrough, body trimming.
This commit is contained in:
win 2026-04-29 01:33:05 +08:00
parent 95814974de
commit d535688bfd
5 changed files with 381 additions and 0 deletions

View File

@ -485,6 +485,10 @@ type GatewayConfig struct {
// RPMSmoothing: RPM 令牌桶平滑配置
// 启用后RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
// ContextCompression: 主动上下文压缩配置
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
}
type GatewayAntigravityLSWorkerConfig struct {
@ -556,6 +560,20 @@ func (c *RPMSmoothingConfig) MaxWait() time.Duration {
return time.Duration(c.MaxWaitMS) * time.Millisecond
}
// ContextCompressionConfig 主动上下文压缩配置
type ContextCompressionConfig struct {
// MaxTokens: 压缩目标 token 数chars/4 近似),超出时从最旧消息开始裁剪(默认 190000
MaxTokens int `mapstructure:"max_tokens"`
}
// GetMaxTokens returns the configured token budget, defaulting to 190 000.
func (c *ContextCompressionConfig) GetMaxTokens() int {
if c.MaxTokens <= 0 {
return 190_000
}
return c.MaxTokens
}
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {

View File

@ -771,6 +771,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
return false
}
// IsContextCompressionEnabled returns true if the account has opted into proactive
// context compression. When enabled, the gateway will trim oldest messages before
// dispatch to keep the estimated token count within the configured budget.
func (a *Account) IsContextCompressionEnabled() bool {
if a.Credentials == nil {
return false
}
enabled, _ := a.Credentials["enable_context_compression"].(bool)
return enabled
}
func (a *Account) IsBedrock() bool {
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
}

View File

@ -0,0 +1,151 @@
package service
import (
"encoding/json"
"math"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
const defaultContextCompressionMaxTokens = 190_000
// approxTokens estimates the token count for a string using the chars/4 heuristic.
func approxTokens(s string) int {
return int(math.Ceil(float64(len(s)) / 4.0))
}
// compressMessagesInBody trims the oldest messages from the request body so that the
// estimated token count of the messages array fits within maxTokens.
// Returns the original body unchanged if no compression is needed or if parsing fails.
func compressMessagesInBody(body []byte, maxTokens int) []byte {
msgsResult := gjson.GetBytes(body, "messages")
if !msgsResult.Exists() || !msgsResult.IsArray() {
return body
}
// Unmarshal to a typed slice for processing.
var messages []map[string]any
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
return body
}
compressed, changed := compressMessages(messages, maxTokens)
if !changed {
return body
}
newMsgs, err := json.Marshal(compressed)
if err != nil {
return body
}
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
if err != nil {
return body
}
return updated
}
// compressMessages removes the oldest messages from the front of msgs until the
// estimated total token count is at or below maxTokens.
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
// to avoid orphaned tool_result blocks.
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
if len(msgs) == 0 {
return msgs, false
}
// Estimate total tokens.
totalTokens := 0
for _, m := range msgs {
totalTokens += msgTokens(m)
}
if totalTokens <= maxTokens {
return msgs, false
}
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
type unit struct {
startIdx int
endIdx int // exclusive
tokens int
}
units := make([]unit, 0, len(msgs))
i := 0
for i < len(msgs) {
toks := msgTokens(msgs[i])
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
toks += msgTokens(msgs[i+1])
units = append(units, unit{i, i + 2, toks})
i += 2
} else {
units = append(units, unit{i, i + 1, toks})
i++
}
}
// Remove units from the front until we are within budget.
// Always keep at least the last unit so we never send an empty messages array.
removeCount := 0
for removeCount < len(units)-1 && totalTokens > maxTokens {
totalTokens -= units[removeCount].tokens
removeCount++
}
if removeCount == 0 {
return msgs, false
}
cutIdx := units[removeCount].startIdx
return msgs[cutIdx:], true
}
// msgTokens estimates token count for a single message using the chars/4 heuristic.
func msgTokens(msg map[string]any) int {
b, err := json.Marshal(msg)
if err != nil {
return 0
}
return approxTokens(string(b))
}
// isAssistantWithToolUse returns true if msg is an assistant message whose content
// contains at least one block with "type": "tool_use".
func isAssistantWithToolUse(msg map[string]any) bool {
role, _ := msg["role"].(string)
if role != "assistant" {
return false
}
return contentContainsType(msg["content"], "tool_use")
}
// isUserWithToolResult returns true if msg is a user message whose content
// contains at least one block with "type": "tool_result".
func isUserWithToolResult(msg map[string]any) bool {
role, _ := msg["role"].(string)
if role != "user" {
return false
}
return contentContainsType(msg["content"], "tool_result")
}
// contentContainsType returns true if content (a []any of blocks) contains a block
// whose "type" field equals blockType.
func contentContainsType(content any, blockType string) bool {
blocks, ok := content.([]any)
if !ok {
return false
}
for _, b := range blocks {
block, ok := b.(map[string]any)
if !ok {
continue
}
if t, _ := block["type"].(string); t == blockType {
return true
}
}
return false
}

View File

@ -0,0 +1,195 @@
package service
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// helpers
func makeMsg(role, text string) map[string]any {
return map[string]any{
"role": role,
"content": text,
}
}
func makeToolUseMsg(id string) map[string]any {
return map[string]any{
"role": "assistant",
"content": []any{
map[string]any{
"type": "tool_use",
"id": id,
"name": "search",
"input": map[string]any{},
},
},
}
}
func makeToolResultMsg(toolUseID string) map[string]any {
return map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "tool_result",
"tool_use_id": toolUseID,
"content": "result text",
},
},
}
}
func toAnySlice(msgs []map[string]any) []any {
out := make([]any, len(msgs))
for i, m := range msgs {
out[i] = m
}
return out
}
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
t.Helper()
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
require.NoError(t, err)
return b
}
// tests
func TestApproxTokens(t *testing.T) {
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
assert.Equal(t, 0, approxTokens(""))
}
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
msgs := []map[string]any{
makeMsg("user", "hi"),
makeMsg("assistant", "hello"),
}
result, changed := compressMessages(msgs, 100_000)
assert.False(t, changed)
assert.Len(t, result, 2)
}
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
// 10 messages, each large enough to be over a tight budget when combined.
msgs := make([]map[string]any, 10)
for i := range msgs {
role := "user"
if i%2 == 1 {
role = "assistant"
}
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
}
// Force compression by using a very small token budget.
result, changed := compressMessages(msgs, 1)
assert.True(t, changed)
// Must keep at least one message (the last).
assert.GreaterOrEqual(t, len(result), 1)
// The remaining messages should be from the tail (newest).
lastOrig := msgs[len(msgs)-1]["content"]
lastResult := result[len(result)-1]["content"]
assert.Equal(t, lastOrig, lastResult)
}
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
// Messages: user → assistant+tool_use → user+tool_result → assistant
msgs := []map[string]any{
makeMsg("user", "start"),
makeToolUseMsg("tool-1"),
makeToolResultMsg("tool-1"),
makeMsg("assistant", "done"),
}
// Budget that forces removal of the first non-paired message but keeps the tool pair.
// Estimate total tokens and set budget to force removing only "start" but not the pair.
total := 0
for _, m := range msgs {
total += msgTokens(m)
}
// Budget: remove "start" but keep tool pair + "done".
startTokens := msgTokens(msgs[0])
budget := total - startTokens
result, changed := compressMessages(msgs, budget)
assert.True(t, changed)
// tool_use and tool_result should both be present or both absent.
hasToolUse := false
hasToolResult := false
for _, m := range result {
if isAssistantWithToolUse(m) {
hasToolUse = true
}
if isUserWithToolResult(m) {
hasToolResult = true
}
}
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
}
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
// Budget forces removal of the tool pair.
msgs := []map[string]any{
makeMsg("user", "start"),
makeToolUseMsg("tool-1"),
makeToolResultMsg("tool-1"),
makeMsg("assistant", "final answer after tool use"),
}
// Budget: only keep the last "assistant" message.
lastTokens := msgTokens(msgs[len(msgs)-1])
result, changed := compressMessages(msgs, lastTokens)
assert.True(t, changed)
// Neither tool_use nor tool_result should remain.
for _, m := range result {
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
}
}
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
result := compressMessagesInBody(body, 1)
assert.Equal(t, body, result, "body without messages should be unchanged")
}
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
msgs := []map[string]any{makeMsg("user", "hi")}
body := bodyWithMessages(t, msgs)
result := compressMessagesInBody(body, 100_000)
assert.Equal(t, body, result, "body under budget should be unchanged")
}
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
msgs := make([]map[string]any, 20)
for i := range msgs {
role := "user"
if i%2 == 1 {
role = "assistant"
}
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
}
body := bodyWithMessages(t, msgs)
// Force significant compression.
result := compressMessagesInBody(body, 50)
assert.Less(t, len(result), len(body), "compressed body should be smaller")
// Resulting body should still be valid JSON with a messages array.
var parsed map[string]any
require.NoError(t, json.Unmarshal(result, &parsed))
resultMsgs, ok := parsed["messages"].([]any)
require.True(t, ok)
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
}

View File

@ -4182,6 +4182,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
body = StripEmptyTextBlocks(body)
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
if account.IsContextCompressionEnabled() {
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
body = compressMessagesInBody(body, maxTok)
}
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
setOpsUpstreamRequestBody(c, body)