feat(context): add proactive context compression for long conversations
- New context_compressor.go: pure functions operating on the raw JSON body (gjson/sjson pattern). approxTokens uses a chars/4 heuristic.
- compressMessages: removes the oldest messages from the front, treating consecutive assistant(tool_use) + user(tool_result) pairs as atomic units to prevent orphaned tool_result blocks.
- Hooked into Forward() after StripEmptyTextBlocks, gated on account.Credentials["enable_context_compression"].
- Config: gateway.context_compression.max_tokens (default 190000).
- 8 unit tests covering: approx tokens, no-op when under budget, oldest-message trimming, tool pair preservation, atomic pair removal, body passthrough, body trimming.
This commit is contained in:
parent
95814974de
commit
d535688bfd
@ -485,6 +485,10 @@ type GatewayConfig struct {
|
|||||||
// RPMSmoothing: RPM 令牌桶平滑配置
|
// RPMSmoothing: RPM 令牌桶平滑配置
|
||||||
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
// 启用后,RPM 配额耗尽时请求等待令牌(最多 MaxWaitMS 毫秒)而非立即返回 429
|
||||||
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
RPMSmoothing RPMSmoothingConfig `mapstructure:"rpm_smoothing"`
|
||||||
|
|
||||||
|
// ContextCompression: 主动上下文压缩配置
|
||||||
|
// 账号启用 enable_context_compression 后,超出 MaxTokens 预算时自动裁剪历史消息
|
||||||
|
ContextCompression ContextCompressionConfig `mapstructure:"context_compression"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayAntigravityLSWorkerConfig struct {
|
type GatewayAntigravityLSWorkerConfig struct {
|
||||||
@ -556,6 +560,20 @@ func (c *RPMSmoothingConfig) MaxWait() time.Duration {
|
|||||||
return time.Duration(c.MaxWaitMS) * time.Millisecond
|
return time.Duration(c.MaxWaitMS) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContextCompressionConfig holds the settings for proactive context compression.
type ContextCompressionConfig struct {
	// MaxTokens is the target token budget (chars/4 approximation). When the
	// estimate exceeds it, history is trimmed starting from the oldest message.
	// Values <= 0 fall back to the default of 190000 (see GetMaxTokens).
	MaxTokens int `mapstructure:"max_tokens"`
}

// GetMaxTokens returns the configured token budget. A zero or negative
// MaxTokens is treated as "unset" and yields the default of 190 000.
func (c *ContextCompressionConfig) GetMaxTokens() int {
	if limit := c.MaxTokens; limit > 0 {
		return limit
	}
	return 190_000
}
|
||||||
|
|
||||||
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
|
||||||
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
|
||||||
type GatewayOpenAIWSConfig struct {
|
type GatewayOpenAIWSConfig struct {
|
||||||
|
|||||||
@ -771,6 +771,17 @@ func (a *Account) IsInterceptWarmupEnabled() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsContextCompressionEnabled returns true if the account has opted into proactive
|
||||||
|
// context compression. When enabled, the gateway will trim oldest messages before
|
||||||
|
// dispatch to keep the estimated token count within the configured budget.
|
||||||
|
func (a *Account) IsContextCompressionEnabled() bool {
|
||||||
|
if a.Credentials == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
enabled, _ := a.Credentials["enable_context_compression"].(bool)
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Account) IsBedrock() bool {
|
func (a *Account) IsBedrock() bool {
|
||||||
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
return a.Platform == PlatformAnthropic && a.Type == AccountTypeBedrock
|
||||||
}
|
}
|
||||||
|
|||||||
151
backend/internal/service/context_compressor.go
Normal file
151
backend/internal/service/context_compressor.go
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultContextCompressionMaxTokens is the default target token budget (chars/4 approximation).
|
||||||
|
// 190K is conservative for a 200K-window model, leaving ~10K headroom for the response.
// NOTE(review): nothing in the visible code references this constant; the config
// getter GetMaxTokens returns its own literal 190_000. Confirm whether this
// declaration is still needed or should become the single source of truth.
|
||||||
|
const defaultContextCompressionMaxTokens = 190_000
|
||||||
|
|
||||||
|
// approxTokens returns a rough token estimate for s: one token per four bytes,
// rounded up. Note that len(s) counts bytes, not runes, so multi-byte text
// contributes more than one unit per character. The empty string yields 0.
func approxTokens(s string) int {
	const bytesPerToken = 4.0
	quarters := float64(len(s)) / bytesPerToken
	return int(math.Ceil(quarters))
}
|
||||||
|
|
||||||
|
// compressMessagesInBody trims the oldest messages from the request body so that the
|
||||||
|
// estimated token count of the messages array fits within maxTokens.
|
||||||
|
// Returns the original body unchanged if no compression is needed or if parsing fails.
|
||||||
|
func compressMessagesInBody(body []byte, maxTokens int) []byte {
|
||||||
|
msgsResult := gjson.GetBytes(body, "messages")
|
||||||
|
if !msgsResult.Exists() || !msgsResult.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal to a typed slice for processing.
|
||||||
|
var messages []map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
compressed, changed := compressMessages(messages, maxTokens)
|
||||||
|
if !changed {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
newMsgs, err := json.Marshal(compressed)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressMessages removes the oldest messages from the front of msgs until the
|
||||||
|
// estimated total token count is at or below maxTokens.
|
||||||
|
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
|
||||||
|
// to avoid orphaned tool_result blocks.
|
||||||
|
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
|
||||||
|
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate total tokens.
|
||||||
|
totalTokens := 0
|
||||||
|
for _, m := range msgs {
|
||||||
|
totalTokens += msgTokens(m)
|
||||||
|
}
|
||||||
|
if totalTokens <= maxTokens {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
|
||||||
|
type unit struct {
|
||||||
|
startIdx int
|
||||||
|
endIdx int // exclusive
|
||||||
|
tokens int
|
||||||
|
}
|
||||||
|
units := make([]unit, 0, len(msgs))
|
||||||
|
i := 0
|
||||||
|
for i < len(msgs) {
|
||||||
|
toks := msgTokens(msgs[i])
|
||||||
|
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
|
||||||
|
toks += msgTokens(msgs[i+1])
|
||||||
|
units = append(units, unit{i, i + 2, toks})
|
||||||
|
i += 2
|
||||||
|
} else {
|
||||||
|
units = append(units, unit{i, i + 1, toks})
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove units from the front until we are within budget.
|
||||||
|
// Always keep at least the last unit so we never send an empty messages array.
|
||||||
|
removeCount := 0
|
||||||
|
for removeCount < len(units)-1 && totalTokens > maxTokens {
|
||||||
|
totalTokens -= units[removeCount].tokens
|
||||||
|
removeCount++
|
||||||
|
}
|
||||||
|
if removeCount == 0 {
|
||||||
|
return msgs, false
|
||||||
|
}
|
||||||
|
|
||||||
|
cutIdx := units[removeCount].startIdx
|
||||||
|
return msgs[cutIdx:], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// msgTokens estimates token count for a single message using the chars/4 heuristic.
|
||||||
|
func msgTokens(msg map[string]any) int {
|
||||||
|
b, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return approxTokens(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAssistantWithToolUse returns true if msg is an assistant message whose content
|
||||||
|
// contains at least one block with "type": "tool_use".
|
||||||
|
func isAssistantWithToolUse(msg map[string]any) bool {
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if role != "assistant" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return contentContainsType(msg["content"], "tool_use")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUserWithToolResult returns true if msg is a user message whose content
|
||||||
|
// contains at least one block with "type": "tool_result".
|
||||||
|
func isUserWithToolResult(msg map[string]any) bool {
|
||||||
|
role, _ := msg["role"].(string)
|
||||||
|
if role != "user" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return contentContainsType(msg["content"], "tool_result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// contentContainsType reports whether content is a list of blocks ([]any) in
// which at least one object block has a "type" field equal to blockType.
// Content of any other shape (for example a plain string) yields false, and
// list entries that are not objects are ignored.
func contentContainsType(content any, blockType string) bool {
	blocks, isList := content.([]any)
	if !isList {
		return false
	}
	for idx := range blocks {
		if block, isObj := blocks[idx].(map[string]any); isObj {
			// A missing or non-string "type" compares as the empty string.
			kind, _ := block["type"].(string)
			if kind == blockType {
				return true
			}
		}
	}
	return false
}
|
||||||
195
backend/internal/service/context_compressor_test.go
Normal file
195
backend/internal/service/context_compressor_test.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
// makeMsg builds a chat message whose content is a plain string.
func makeMsg(role, text string) map[string]any {
	msg := make(map[string]any, 2)
	msg["role"] = role
	msg["content"] = text
	return msg
}
|
||||||
|
|
||||||
|
// makeToolUseMsg builds an assistant message whose content is a single
// tool_use block with the given id, the tool name "search" and an empty input.
func makeToolUseMsg(id string) map[string]any {
	toolUse := map[string]any{
		"type":  "tool_use",
		"id":    id,
		"name":  "search",
		"input": map[string]any{},
	}
	return map[string]any{
		"role":    "assistant",
		"content": []any{toolUse},
	}
}
|
||||||
|
|
||||||
|
// makeToolResultMsg builds a user message whose content is a single
// tool_result block answering the tool_use identified by toolUseID.
func makeToolResultMsg(toolUseID string) map[string]any {
	toolResult := map[string]any{
		"type":        "tool_result",
		"tool_use_id": toolUseID,
		"content":     "result text",
	}
	return map[string]any{
		"role":    "user",
		"content": []any{toolResult},
	}
}
|
||||||
|
|
||||||
|
// toAnySlice converts a slice of messages into a []any holding the same maps
// in the same order. The result is always non-nil.
func toAnySlice(msgs []map[string]any) []any {
	converted := make([]any, 0, len(msgs))
	for _, m := range msgs {
		converted = append(converted, m)
	}
	return converted
}
|
||||||
|
|
||||||
|
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
|
||||||
|
t.Helper()
|
||||||
|
b, err := json.Marshal(map[string]any{"messages": msgs, "model": "claude-3-5-sonnet-20241022"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// tests
|
||||||
|
|
||||||
|
func TestApproxTokens(t *testing.T) {
|
||||||
|
assert.Equal(t, 1, approxTokens("four")) // 4 chars → 1 token
|
||||||
|
assert.Equal(t, 3, approxTokens("0123456789ab")) // 12 chars → 3 tokens
|
||||||
|
assert.Equal(t, 0, approxTokens(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "hi"),
|
||||||
|
makeMsg("assistant", "hello"),
|
||||||
|
}
|
||||||
|
result, changed := compressMessages(msgs, 100_000)
|
||||||
|
assert.False(t, changed)
|
||||||
|
assert.Len(t, result, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
|
||||||
|
// 10 messages, each large enough to be over a tight budget when combined.
|
||||||
|
msgs := make([]map[string]any, 10)
|
||||||
|
for i := range msgs {
|
||||||
|
role := "user"
|
||||||
|
if i%2 == 1 {
|
||||||
|
role = "assistant"
|
||||||
|
}
|
||||||
|
msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force compression by using a very small token budget.
|
||||||
|
result, changed := compressMessages(msgs, 1)
|
||||||
|
assert.True(t, changed)
|
||||||
|
// Must keep at least one message (the last).
|
||||||
|
assert.GreaterOrEqual(t, len(result), 1)
|
||||||
|
// The remaining messages should be from the tail (newest).
|
||||||
|
lastOrig := msgs[len(msgs)-1]["content"]
|
||||||
|
lastResult := result[len(result)-1]["content"]
|
||||||
|
assert.Equal(t, lastOrig, lastResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
|
||||||
|
// Messages: user → assistant+tool_use → user+tool_result → assistant
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "start"),
|
||||||
|
makeToolUseMsg("tool-1"),
|
||||||
|
makeToolResultMsg("tool-1"),
|
||||||
|
makeMsg("assistant", "done"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget that forces removal of the first non-paired message but keeps the tool pair.
|
||||||
|
// Estimate total tokens and set budget to force removing only "start" but not the pair.
|
||||||
|
total := 0
|
||||||
|
for _, m := range msgs {
|
||||||
|
total += msgTokens(m)
|
||||||
|
}
|
||||||
|
// Budget: remove "start" but keep tool pair + "done".
|
||||||
|
startTokens := msgTokens(msgs[0])
|
||||||
|
budget := total - startTokens
|
||||||
|
|
||||||
|
result, changed := compressMessages(msgs, budget)
|
||||||
|
assert.True(t, changed)
|
||||||
|
|
||||||
|
// tool_use and tool_result should both be present or both absent.
|
||||||
|
hasToolUse := false
|
||||||
|
hasToolResult := false
|
||||||
|
for _, m := range result {
|
||||||
|
if isAssistantWithToolUse(m) {
|
||||||
|
hasToolUse = true
|
||||||
|
}
|
||||||
|
if isUserWithToolResult(m) {
|
||||||
|
hasToolResult = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
|
||||||
|
// Budget forces removal of the tool pair.
|
||||||
|
msgs := []map[string]any{
|
||||||
|
makeMsg("user", "start"),
|
||||||
|
makeToolUseMsg("tool-1"),
|
||||||
|
makeToolResultMsg("tool-1"),
|
||||||
|
makeMsg("assistant", "final answer after tool use"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget: only keep the last "assistant" message.
|
||||||
|
lastTokens := msgTokens(msgs[len(msgs)-1])
|
||||||
|
|
||||||
|
result, changed := compressMessages(msgs, lastTokens)
|
||||||
|
assert.True(t, changed)
|
||||||
|
|
||||||
|
// Neither tool_use nor tool_result should remain.
|
||||||
|
for _, m := range result {
|
||||||
|
assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
|
||||||
|
assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
|
||||||
|
result := compressMessagesInBody(body, 1)
|
||||||
|
assert.Equal(t, body, result, "body without messages should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
|
||||||
|
msgs := []map[string]any{makeMsg("user", "hi")}
|
||||||
|
body := bodyWithMessages(t, msgs)
|
||||||
|
result := compressMessagesInBody(body, 100_000)
|
||||||
|
assert.Equal(t, body, result, "body under budget should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
|
||||||
|
msgs := make([]map[string]any, 20)
|
||||||
|
for i := range msgs {
|
||||||
|
role := "user"
|
||||||
|
if i%2 == 1 {
|
||||||
|
role = "assistant"
|
||||||
|
}
|
||||||
|
msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
|
||||||
|
}
|
||||||
|
body := bodyWithMessages(t, msgs)
|
||||||
|
|
||||||
|
// Force significant compression.
|
||||||
|
result := compressMessagesInBody(body, 50)
|
||||||
|
assert.Less(t, len(result), len(body), "compressed body should be smaller")
|
||||||
|
|
||||||
|
// Resulting body should still be valid JSON with a messages array.
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(result, &parsed))
|
||||||
|
resultMsgs, ok := parsed["messages"].([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Greater(t, len(resultMsgs), 0, "messages array should not be empty")
|
||||||
|
}
|
||||||
@ -4182,6 +4182,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
|
||||||
body = StripEmptyTextBlocks(body)
|
body = StripEmptyTextBlocks(body)
|
||||||
|
|
||||||
|
// 主动上下文压缩:裁剪超出 token 预算的历史消息,保留 tool_use/tool_result 对完整性。
|
||||||
|
if account.IsContextCompressionEnabled() {
|
||||||
|
maxTok := s.cfg.Gateway.ContextCompression.GetMaxTokens()
|
||||||
|
body = compressMessagesInBody(body, maxTok)
|
||||||
|
}
|
||||||
|
|
||||||
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
// 重试间复用同一请求体,避免每次 string(body) 产生额外分配。
|
||||||
setOpsUpstreamRequestBody(c, body)
|
setOpsUpstreamRequestBody(c, body)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user