sub2api/backend/internal/service/context_compressor.go
win d535688bfd feat(context): add proactive context compression for long conversations
- New context_compressor.go: pure functions operating on raw JSON body
  (gjson/sjson pattern). approxTokens uses chars/4 heuristic.
- compressMessages: removes oldest messages from front, treating
  consecutive assistant(tool_use)+user(tool_result) pairs as atomic units
  to prevent orphaned tool_result blocks.
- Hooked into Forward() after StripEmptyTextBlocks, gated on
  account.Credentials[enable_context_compression].
- Config: gateway.context_compression.max_tokens (default 190000).
- 8 unit tests covering: approx tokens, no-op when under budget,
  oldest-message trimming, tool pair preservation, atomic pair removal,
  body passthrough, body trimming.
2026-04-29 01:33:05 +08:00

152 lines
4.2 KiB
Go

package service
import (
"encoding/json"
"math"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// defaultContextCompressionMaxTokens is the default budget, in estimated tokens
// (bytes/4, as computed by approxTokens), that the messages array is trimmed to.
// Against a 200K-token context window this leaves roughly 10K tokens of headroom
// for the model's response.
const defaultContextCompressionMaxTokens = 190_000
// approxTokens returns a rough token estimate for s: its length in bytes divided
// by four, rounded up, so any non-empty string counts as at least one token.
func approxTokens(s string) int {
	byteLen := float64(len(s))
	return int(math.Ceil(byteLen / 4))
}
// compressMessagesInBody trims the oldest messages from the request body so that the
// estimated token count of the "messages" array fits within maxTokens.
//
// Only the "messages" field is rewritten. The retained messages are copied verbatim
// from the original JSON instead of being decoded and re-encoded. Re-encoding via
// map[string]any would re-sort object keys, HTML-escape '<', '>' and '&', and round
// every number through float64, which silently alters the forwarded conversation.
//
// Returns the original body unchanged if no compression is needed or if parsing or
// rewriting fails (best effort: the request is then forwarded as-is).
func compressMessagesInBody(body []byte, maxTokens int) []byte {
	msgsResult := gjson.GetBytes(body, "messages")
	if !msgsResult.Exists() || !msgsResult.IsArray() {
		return body
	}
	// Raw JSON of each element, used to write retained messages back byte-for-byte.
	rawMsgs := msgsResult.Array()

	// A generic decoded view is used only to decide how many messages to drop.
	var messages []map[string]any
	if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
		return body
	}
	if len(messages) != len(rawMsgs) {
		// Should not happen for valid JSON; refuse to guess which raw element is which.
		return body
	}
	compressed, changed := compressMessages(messages, maxTokens)
	if !changed {
		return body
	}
	// compressMessages only ever removes a prefix, so the kept messages are the
	// last len(compressed) raw elements.
	cut := len(messages) - len(compressed)
	if cut <= 0 || cut > len(rawMsgs) {
		return body
	}

	newMsgs := make([]byte, 0, len(msgsResult.Raw))
	newMsgs = append(newMsgs, '[')
	for i, m := range rawMsgs[cut:] {
		if i > 0 {
			newMsgs = append(newMsgs, ',')
		}
		newMsgs = append(newMsgs, m.Raw...)
	}
	newMsgs = append(newMsgs, ']')

	updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
	if err != nil {
		return body
	}
	return updated
}
// compressMessages removes the oldest messages from the front of msgs until the
// estimated total token count (see msgTokens) is at or below maxTokens.
//
// An assistant message containing a tool_use block that is immediately followed by
// a user message containing a tool_result block forms one atomic unit. Such a pair
// is kept or dropped together, so no tool_result is left orphaned. The final unit is
// always kept, even if it alone exceeds maxTokens, so the result is never empty.
//
// The returned slice is always a suffix of msgs and shares its backing array.
// Returns (msgs, false) if nothing was removed, or (trimmed, true) otherwise.
//
// NOTE(review): with the common layout user, assistant(tool_use), user(tool_result), ...
// the first retained message is often an assistant turn. Confirm the upstream API
// accepts a conversation that does not open with a user message.
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
	if len(msgs) == 0 {
		return msgs, false
	}

	// Partition msgs into atomic removal units and estimate every unit's size in a
	// single pass, so each message is JSON-encoded for estimation only once.
	type unit struct {
		startIdx int // index in msgs of the unit's first message
		tokens   int // estimated tokens of all messages in the unit
	}
	units := make([]unit, 0, len(msgs))
	totalTokens := 0
	for i := 0; i < len(msgs); {
		u := unit{startIdx: i, tokens: msgTokens(msgs[i])}
		if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
			u.tokens += msgTokens(msgs[i+1])
			i += 2
		} else {
			i++
		}
		units = append(units, u)
		totalTokens += u.tokens
	}
	if totalTokens <= maxTokens {
		return msgs, false
	}

	// Drop whole units from the front until within budget, never dropping the last one.
	removeCount := 0
	for removeCount < len(units)-1 && totalTokens > maxTokens {
		totalTokens -= units[removeCount].tokens
		removeCount++
	}
	if removeCount == 0 {
		return msgs, false
	}
	return msgs[units[removeCount].startIdx:], true
}
// msgTokens returns the approxTokens estimate for the JSON encoding of msg.
// If msg cannot be encoded, it counts as zero tokens.
func msgTokens(msg map[string]any) int {
	encoded, err := json.Marshal(msg)
	if err == nil {
		return approxTokens(string(encoded))
	}
	return 0
}
// isAssistantWithToolUse reports whether msg has role "assistant" and a content
// array holding at least one block whose "type" is "tool_use".
func isAssistantWithToolUse(msg map[string]any) bool {
	if role, _ := msg["role"].(string); role == "assistant" {
		return contentContainsType(msg["content"], "tool_use")
	}
	return false
}
// isUserWithToolResult reports whether msg has role "user" and a content array
// holding at least one block whose "type" is "tool_result".
func isUserWithToolResult(msg map[string]any) bool {
	role, _ := msg["role"].(string)
	return role == "user" && contentContainsType(msg["content"], "tool_result")
}
// contentContainsType reports whether content is a []any in which at least one
// map-shaped element has a string "type" field equal to blockType.
// Any other content shape (for example a plain string or nil) yields false, and
// non-map elements inside the slice are skipped.
func contentContainsType(content any, blockType string) bool {
	blocks, isList := content.([]any)
	if !isList {
		return false
	}
	for idx := range blocks {
		if block, isObj := blocks[idx].(map[string]any); isObj {
			if kind, _ := block["type"].(string); kind == blockType {
				return true
			}
		}
	}
	return false
}