feat(context): add proactive context compression for long conversations
- New context_compressor.go: pure functions operating on raw JSON body (gjson/sjson pattern). approxTokens uses chars/4 heuristic. - compressMessages: removes oldest messages from front, treating consecutive assistant(tool_use)+user(tool_result) pairs as atomic units to prevent orphaned tool_result blocks. - Hooked into Forward() after StripEmptyTextBlocks, gated on account.Credentials[enable_context_compression]. - Config: gateway.context_compression.max_tokens (default 190000). - 8 unit tests covering: approx tokens, no-op when under budget, oldest-message trimming, tool pair preservation, atomic pair removal, body passthrough, body trimming.
This commit is contained in:
parent
95814974de
commit
d535688bfd
@ -485,6 +485,10 @@ type GatewayConfig struct {
|
||||
// RPMSmoothing: RPM 令牌桶平滑配置
|
||||
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
||||
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
||||
|
||||
// ContextCompression: 主动上下文压缩配置
|
||||
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
|
||||
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
|
||||
}
|
||||
|
||||
type GatewayAntigravityLSWorkerConfig struct {
|
||||
@ -556,6 +560,20 @@ func (c *RPMSmoothingConfig) MaxWait() time.Duration {
|
||||
return time.Duration(c.MaxWaitMS) * time.Millisecond
|
||||
}
|
||||
|
||||
// ContextCompressionConfig 主动上下文压缩配置
|
||||
type ContextCompressionConfig struct {
|
||||
// MaxTokens: 压缩目标 token 数(chars/4 近似),超出时从最旧消息开始裁剪(默认 190000)
|
||||
MaxTokens int `mapstructure:"max_tokens"`
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the configured token budget, defaulting to 190 000.
|
||||
func (c *ContextCompressionConfig) GetMaxTokens() int {
|
||||
if c.MaxTokens <= 0 {
|
||||
return 190_000
|
||||
}
|
||||
return c.MaxTokens
|
||||
}
|
||||
|
||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||
type GatewayOpenAIWSConfig struct {
|
||||
|
||||
@ -771,6 +771,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsContextCompressionEnabled returns true if the account has opted into proactive
|
||||
// context compression. When enabled, the gateway will trim oldest messages before
|
||||
// dispatch to keep the estimated token count within the configured budget.
|
||||
func (a *Account) IsContextCompressionEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
enabled, _ := a.Credentials["enable_context_compression"].(bool)
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) IsBedrock() bool {
|
||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||
}
|
||||
|
||||
151
backend/internal/service/context_compressor.go
Normal file
151
backend/internal/service/context_compressor.go
Normal file
@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
|
||||
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
|
||||
const defaultContextCompressionMaxTokens = 190_000
|
||||
|
||||
// approxTokens estimates the token count for a string using the chars/4 heuristic.
|
||||
func approxTokens(s string) int {
|
||||
return int(math.Ceil(float64(len(s)) / 4.0))
|
||||
}
|
||||
|
||||
// compressMessagesInBody trims the oldest messages from the request body so that the
|
||||
// estimated token count of the messages array fits within maxTokens.
|
||||
// Returns the original body unchanged if no compression is needed or if parsing fails.
|
||||
func compressMessagesInBody(body []byte, maxTokens int) []byte {
|
||||
msgsResult := gjson.GetBytes(body, "messages")
|
||||
if !msgsResult.Exists() || !msgsResult.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
// Unmarshal to a typed slice for processing.
|
||||
var messages []map[string]any
|
||||
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
compressed, changed := compressMessages(messages, maxTokens)
|
||||
if !changed {
|
||||
return body
|
||||
}
|
||||
|
||||
newMsgs, err := json.Marshal(compressed)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// compressMessages removes the oldest messages from the front of msgs until the
|
||||
// estimated total token count is at or below maxTokens.
|
||||
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
|
||||
// to avoid orphaned tool_result blocks.
|
||||
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
|
||||
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
|
||||
if len(msgs) == 0 {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
// Estimate total tokens.
|
||||
totalTokens := 0
|
||||
for _, m := range msgs {
|
||||
totalTokens += msgTokens(m)
|
||||
}
|
||||
if totalTokens <= maxTokens {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
|
||||
type unit struct {
|
||||
startIdx int
|
||||
endIdx int // exclusive
|
||||
tokens int
|
||||
}
|
||||
units := make([]unit, 0, len(msgs))
|
||||
i := 0
|
||||
for i < len(msgs) {
|
||||
toks := msgTokens(msgs[i])
|
||||
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
|
||||
toks += msgTokens(msgs[i+1])
|
||||
units = append(units, unit{i, i + 2, toks})
|
||||
i += 2
|
||||
} else {
|
||||
units = append(units, unit{i, i + 1, toks})
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// Remove units from the front until we are within budget.
|
||||
// Always keep at least the last unit so we never send an empty messages array.
|
||||
removeCount := 0
|
||||
for removeCount < len(units)-1 && totalTokens > maxTokens {
|
||||
totalTokens -= units[removeCount].tokens
|
||||
removeCount++
|
||||
}
|
||||
if removeCount == 0 {
|
||||
return msgs, false
|
||||
}
|
||||
|
||||
cutIdx := units[removeCount].startIdx
|
||||
return msgs[cutIdx:], true
|
||||
}
|
||||
|
||||
// msgTokens estimates token count for a single message using the chars/4 heuristic.
|
||||
func msgTokens(msg map[string]any) int {
|
||||
b, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return approxTokens(string(b))
|
||||
}
|
||||
|
||||
// isAssistantWithToolUse returns true if msg is an assistant message whose content
|
||||
// contains at least one block with "type": "tool_use".
|
||||
func isAssistantWithToolUse(msg map[string]any) bool {
|
||||
role, _ := msg["role"].(string)
|
||||
if role != "assistant" {
|
||||
return false
|
||||
}
|
||||
return contentContainsType(msg["content"], "tool_use")
|
||||
}
|
||||
|
||||
// isUserWithToolResult returns true if msg is a user message whose content
|
||||
// contains at least one block with "type": "tool_result".
|
||||
func isUserWithToolResult(msg map[string]any) bool {
|
||||
role, _ := msg["role"].(string)
|
||||
if role != "user" {
|
||||
return false
|
||||
}
|
||||
return contentContainsType(msg["content"], "tool_result")
|
||||
}
|
||||
|
||||
// contentContainsType returns true if content (a []any of blocks) contains a block
|
||||
// whose "type" field equals blockType.
|
||||
func contentContainsType(content any, blockType string) bool {
|
||||
blocks, ok := content.([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, b := range blocks {
|
||||
block, ok := b.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if t, _ := block["type"].(string); t == blockType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
195
backend/internal/service/context_compressor_test.go
Normal file
195
backend/internal/service/context_compressor_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// helpers
|
||||
|
||||
func makeMsg(role, text string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": text,
|
||||
}
|
||||
}
|
||||
|
||||
func makeToolUseMsg(id string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": id,
|
||||
"name": "search",
|
||||
"input": map[string]any{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func makeToolResultMsg(toolUseID string) map[string]any {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": "result text",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toAnySlice(msgs []map[string]any) []any {
|
||||
out := make([]any, len(msgs))
|
||||
for i, m := range msgs {
|
||||
out[i] = m
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
|
||||
require.NoError(t, err)
|
||||
return b
|
||||
}
|
||||
|
||||
// tests
|
||||
|
||||
func TestApproxTokens(t *testing.T) {
|
||||
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
|
||||
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
|
||||
assert.Equal(t, 0, approxTokens(""))
|
||||
}
|
||||
|
||||
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "hi"),
|
||||
makeMsg("assistant", "hello"),
|
||||
}
|
||||
result, changed := compressMessages(msgs, 100_000)
|
||||
assert.False(t, changed)
|
||||
assert.Len(t, result, 2)
|
||||
}
|
||||
|
||||
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
|
||||
// 10 messages, each large enough to be over a tight budget when combined.
|
||||
msgs := make([]map[string]any, 10)
|
||||
for i := range msgs {
|
||||
role := "user"
|
||||
if i%2 == 1 {
|
||||
role = "assistant"
|
||||
}
|
||||
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
|
||||
}
|
||||
|
||||
// Force compression by using a very small token budget.
|
||||
result, changed := compressMessages(msgs, 1)
|
||||
assert.True(t, changed)
|
||||
// Must keep at least one message (the last).
|
||||
assert.GreaterOrEqual(t, len(result), 1)
|
||||
// The remaining messages should be from the tail (newest).
|
||||
lastOrig := msgs[len(msgs)-1]["content"]
|
||||
lastResult := result[len(result)-1]["content"]
|
||||
assert.Equal(t, lastOrig, lastResult)
|
||||
}
|
||||
|
||||
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
|
||||
// Messages: user → assistant+tool_use → user+tool_result → assistant
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "start"),
|
||||
makeToolUseMsg("tool-1"),
|
||||
makeToolResultMsg("tool-1"),
|
||||
makeMsg("assistant", "done"),
|
||||
}
|
||||
|
||||
// Budget that forces removal of the first non-paired message but keeps the tool pair.
|
||||
// Estimate total tokens and set budget to force removing only "start" but not the pair.
|
||||
total := 0
|
||||
for _, m := range msgs {
|
||||
total += msgTokens(m)
|
||||
}
|
||||
// Budget: remove "start" but keep tool pair + "done".
|
||||
startTokens := msgTokens(msgs[0])
|
||||
budget := total - startTokens
|
||||
|
||||
result, changed := compressMessages(msgs, budget)
|
||||
assert.True(t, changed)
|
||||
|
||||
// tool_use and tool_result should both be present or both absent.
|
||||
hasToolUse := false
|
||||
hasToolResult := false
|
||||
for _, m := range result {
|
||||
if isAssistantWithToolUse(m) {
|
||||
hasToolUse = true
|
||||
}
|
||||
if isUserWithToolResult(m) {
|
||||
hasToolResult = true
|
||||
}
|
||||
}
|
||||
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
|
||||
}
|
||||
|
||||
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
|
||||
// Budget forces removal of the tool pair.
|
||||
msgs := []map[string]any{
|
||||
makeMsg("user", "start"),
|
||||
makeToolUseMsg("tool-1"),
|
||||
makeToolResultMsg("tool-1"),
|
||||
makeMsg("assistant", "final answer after tool use"),
|
||||
}
|
||||
|
||||
// Budget: only keep the last "assistant" message.
|
||||
lastTokens := msgTokens(msgs[len(msgs)-1])
|
||||
|
||||
result, changed := compressMessages(msgs, lastTokens)
|
||||
assert.True(t, changed)
|
||||
|
||||
// Neither tool_use nor tool_result should remain.
|
||||
for _, m := range result {
|
||||
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
|
||||
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
|
||||
result := compressMessagesInBody(body, 1)
|
||||
assert.Equal(t, body, result, "body without messages should be unchanged")
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
|
||||
msgs := []map[string]any{makeMsg("user", "hi")}
|
||||
body := bodyWithMessages(t, msgs)
|
||||
result := compressMessagesInBody(body, 100_000)
|
||||
assert.Equal(t, body, result, "body under budget should be unchanged")
|
||||
}
|
||||
|
||||
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
|
||||
msgs := make([]map[string]any, 20)
|
||||
for i := range msgs {
|
||||
role := "user"
|
||||
if i%2 == 1 {
|
||||
role = "assistant"
|
||||
}
|
||||
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
|
||||
}
|
||||
body := bodyWithMessages(t, msgs)
|
||||
|
||||
// Force significant compression.
|
||||
result := compressMessagesInBody(body, 50)
|
||||
assert.Less(t, len(result), len(body), "compressed body should be smaller")
|
||||
|
||||
// Resulting body should still be valid JSON with a messages array.
|
||||
var parsed map[string]any
|
||||
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||
resultMsgs, ok := parsed["messages"].([]any)
|
||||
require.True(t, ok)
|
||||
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
|
||||
}
|
||||
@ -4182,6 +4182,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||
body = StripEmptyTextBlocks(body)
|
||||
|
||||
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
|
||||
if account.IsContextCompressionEnabled() {
|
||||
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
|
||||
body = compressMessagesInBody(body, maxTok)
|
||||
}
|
||||
|
||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user