- New context_compressor.go: pure functions operating on raw JSON body (gjson/sjson pattern). approxTokens uses chars/4 heuristic. - compressMessages: removes oldest messages from front, treating consecutive assistant(tool_use)+user(tool_result) pairs as atomic units to prevent orphaned tool_result blocks. - Hooked into Forward() after StripEmptyTextBlocks, gated on account.Credentials[enable_context_compression]. - Config: gateway.context_compression.max_tokens (default 190000). - 8 unit tests covering: approx tokens, no-op when under budget, oldest-message trimming, tool pair preservation, atomic pair removal, body passthrough, body trimming.
152 lines
4.2 KiB
Go
152 lines
4.2 KiB
Go
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"math"
|
|
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// defaultContextCompressionMaxTokens is the default target token budget for the
// messages array. Budgets are measured with the chars/4 approximation, so 190K is
// a conservative target for a 200K-window model and leaves ~10K of headroom for
// the response.
const defaultContextCompressionMaxTokens = 190 * 1000
|
|
|
|
// approxTokens returns a rough token estimate for s: its length in bytes divided
// by four, rounded up. An empty string yields 0.
func approxTokens(s string) int {
	quarters := float64(len(s)) / 4.0
	return int(math.Ceil(quarters))
}
|
|
|
|
// compressMessagesInBody trims the oldest messages from the request body so that the
|
|
// estimated token count of the messages array fits within maxTokens.
|
|
// Returns the original body unchanged if no compression is needed or if parsing fails.
|
|
func compressMessagesInBody(body []byte, maxTokens int) []byte {
|
|
msgsResult := gjson.GetBytes(body, "messages")
|
|
if !msgsResult.Exists() || !msgsResult.IsArray() {
|
|
return body
|
|
}
|
|
|
|
// Unmarshal to a typed slice for processing.
|
|
var messages []map[string]any
|
|
if err := json.Unmarshal([]byte(msgsResult.Raw), &messages); err != nil {
|
|
return body
|
|
}
|
|
|
|
compressed, changed := compressMessages(messages, maxTokens)
|
|
if !changed {
|
|
return body
|
|
}
|
|
|
|
newMsgs, err := json.Marshal(compressed)
|
|
if err != nil {
|
|
return body
|
|
}
|
|
updated, err := sjson.SetRawBytes(body, "messages", newMsgs)
|
|
if err != nil {
|
|
return body
|
|
}
|
|
return updated
|
|
}
|
|
|
|
// compressMessages removes the oldest messages from the front of msgs until the
|
|
// estimated total token count is at or below maxTokens.
|
|
// tool_use (assistant) and tool_result (user) consecutive pairs are removed atomically
|
|
// to avoid orphaned tool_result blocks.
|
|
// Returns (msgs, false) if no compression was needed, or (trimmed, true) otherwise.
|
|
func compressMessages(msgs []map[string]any, maxTokens int) ([]map[string]any, bool) {
|
|
if len(msgs) == 0 {
|
|
return msgs, false
|
|
}
|
|
|
|
// Estimate total tokens.
|
|
totalTokens := 0
|
|
for _, m := range msgs {
|
|
totalTokens += msgTokens(m)
|
|
}
|
|
if totalTokens <= maxTokens {
|
|
return msgs, false
|
|
}
|
|
|
|
// Build atomic removal units: tool_use+tool_result consecutive pairs are one unit.
|
|
type unit struct {
|
|
startIdx int
|
|
endIdx int // exclusive
|
|
tokens int
|
|
}
|
|
units := make([]unit, 0, len(msgs))
|
|
i := 0
|
|
for i < len(msgs) {
|
|
toks := msgTokens(msgs[i])
|
|
if isAssistantWithToolUse(msgs[i]) && i+1 < len(msgs) && isUserWithToolResult(msgs[i+1]) {
|
|
toks += msgTokens(msgs[i+1])
|
|
units = append(units, unit{i, i + 2, toks})
|
|
i += 2
|
|
} else {
|
|
units = append(units, unit{i, i + 1, toks})
|
|
i++
|
|
}
|
|
}
|
|
|
|
// Remove units from the front until we are within budget.
|
|
// Always keep at least the last unit so we never send an empty messages array.
|
|
removeCount := 0
|
|
for removeCount < len(units)-1 && totalTokens > maxTokens {
|
|
totalTokens -= units[removeCount].tokens
|
|
removeCount++
|
|
}
|
|
if removeCount == 0 {
|
|
return msgs, false
|
|
}
|
|
|
|
cutIdx := units[removeCount].startIdx
|
|
return msgs[cutIdx:], true
|
|
}
|
|
|
|
// msgTokens estimates token count for a single message using the chars/4 heuristic.
|
|
func msgTokens(msg map[string]any) int {
|
|
b, err := json.Marshal(msg)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return approxTokens(string(b))
|
|
}
|
|
|
|
// isAssistantWithToolUse returns true if msg is an assistant message whose content
|
|
// contains at least one block with "type": "tool_use".
|
|
func isAssistantWithToolUse(msg map[string]any) bool {
|
|
role, _ := msg["role"].(string)
|
|
if role != "assistant" {
|
|
return false
|
|
}
|
|
return contentContainsType(msg["content"], "tool_use")
|
|
}
|
|
|
|
// isUserWithToolResult returns true if msg is a user message whose content
|
|
// contains at least one block with "type": "tool_result".
|
|
func isUserWithToolResult(msg map[string]any) bool {
|
|
role, _ := msg["role"].(string)
|
|
if role != "user" {
|
|
return false
|
|
}
|
|
return contentContainsType(msg["content"], "tool_result")
|
|
}
|
|
|
|
// contentContainsType reports whether content is a []any holding at least one
// map[string]any block whose "type" value equals blockType. Any other shape of
// content (for example a plain string or nil) yields false, and list elements
// that are not objects are skipped.
func contentContainsType(content any, blockType string) bool {
	items, isList := content.([]any)
	if !isList {
		return false
	}
	for idx := range items {
		if fields, isObject := items[idx].(map[string]any); isObject {
			// A missing or non-string "type" reads as the empty string.
			kind, _ := fields["type"].(string)
			if kind == blockType {
				return true
			}
		}
	}
	return false
}
|