- New context_compressor.go: pure functions that operate on the raw JSON body (gjson/sjson pattern). approxTokens estimates tokens with a chars/4 heuristic.
- compressMessages: removes the oldest messages from the front, treating each consecutive assistant(tool_use) + user(tool_result) pair as an atomic unit so that no tool_result block is left orphaned.
- Hooked into Forward() after StripEmptyTextBlocks, gated on account.Credentials[enable_context_compression].
- Config: gateway.context_compression.max_tokens (default 190000).
- 8 unit tests covering: approx tokens, no-op when under budget, oldest-message trimming, tool pair preservation, atomic pair removal, body passthrough (no messages / under budget), and body trimming.
196 lines
5.4 KiB
Go
196 lines
5.4 KiB
Go
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// helpers
|
|
|
|
// makeMsg builds a plain-text chat message with the given role, using the
// string shorthand form for "content".
func makeMsg(role, text string) map[string]any {
	msg := make(map[string]any, 2)
	msg["role"] = role
	msg["content"] = text
	return msg
}
|
|
|
|
// makeToolUseMsg builds an assistant message whose only content block is a
// tool_use call to a "search" tool with empty input, identified by id.
func makeToolUseMsg(id string) map[string]any {
	block := map[string]any{
		"type":  "tool_use",
		"id":    id,
		"name":  "search",
		"input": map[string]any{},
	}
	return map[string]any{"role": "assistant", "content": []any{block}}
}
|
|
|
|
// makeToolResultMsg builds a user message carrying a single tool_result block
// that answers the tool_use identified by toolUseID.
func makeToolResultMsg(toolUseID string) map[string]any {
	result := map[string]any{
		"type":        "tool_result",
		"tool_use_id": toolUseID,
		"content":     "result text",
	}
	return map[string]any{"role": "user", "content": []any{result}}
}
|
|
|
|
// toAnySlice widens a slice of message maps into a []any holding the same
// maps, in the same order. The result is always non-nil.
func toAnySlice(msgs []map[string]any) []any {
	widened := make([]any, 0, len(msgs))
	for _, msg := range msgs {
		widened = append(widened, msg)
	}
	return widened
}
|
|
|
|
// bodyWithMessages JSON-encodes msgs into a minimal request body alongside a
// fixed model name. It fails the test immediately if encoding fails.
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
	t.Helper()
	payload := map[string]any{
		"model":    "claude-3-5-sonnet-20241022",
		"messages": msgs,
	}
	raw, err := json.Marshal(payload)
	require.NoError(t, err)
	return raw
}
|
|
|
|
// tests
|
|
|
|
// TestApproxTokens pins the chars/4 token estimate on a few fixed inputs.
func TestApproxTokens(t *testing.T) {
	cases := []struct {
		in   string
		want int
	}{
		{"four", 1},         // 4 chars → 1 token
		{"0123456789ab", 3}, // 12 chars → 3 tokens
		{"", 0},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, approxTokens(tc.in), "input %q", tc.in)
	}
}
|
|
|
|
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
|
|
msgs := []map[string]any{
|
|
makeMsg("user", "hi"),
|
|
makeMsg("assistant", "hello"),
|
|
}
|
|
result, changed := compressMessages(msgs, 100_000)
|
|
assert.False(t, changed)
|
|
assert.Len(t, result, 2)
|
|
}
|
|
|
|
// TestCompressMessages_TrimsOldestMessages checks that with a budget far too
// small, messages are dropped from the front. The survivors must be an
// in-order tail of the original conversation that ends with the newest message.
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
	// 10 alternating user/assistant messages, each large enough to be over a
	// tight budget when combined.
	msgs := make([]map[string]any, 10)
	for i := range msgs {
		role := "user"
		if i%2 == 1 {
			role = "assistant"
		}
		msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
	}

	// Force compression by using a very small token budget.
	result, changed := compressMessages(msgs, 1)
	assert.True(t, changed)

	// Must keep at least one message (the last). require (not assert) stops
	// the test here, so the indexing below cannot panic on an empty result.
	require.NotEmpty(t, result)
	// Something must actually have been removed. This also guarantees the
	// offset below is non-negative.
	require.Less(t, len(result), len(msgs), "some messages should have been removed")

	// The remaining messages should be the tail (newest) of the original, in
	// order. Each test message has unique content, so comparing content
	// identifies it.
	offset := len(msgs) - len(result)
	for i, m := range result {
		assert.Equal(t, msgs[offset+i]["content"], m["content"],
			"result[%d] should be original message %d", i, offset+i)
	}
}
|
|
|
|
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
|
|
// Messages: user → assistant+tool_use → user+tool_result → assistant
|
|
msgs := []map[string]any{
|
|
makeMsg("user", "start"),
|
|
makeToolUseMsg("tool-1"),
|
|
makeToolResultMsg("tool-1"),
|
|
makeMsg("assistant", "done"),
|
|
}
|
|
|
|
// Budget that forces removal of the first non-paired message but keeps the tool pair.
|
|
// Estimate total tokens and set budget to force removing only "start" but not the pair.
|
|
total := 0
|
|
for _, m := range msgs {
|
|
total += msgTokens(m)
|
|
}
|
|
// Budget: remove "start" but keep tool pair + "done".
|
|
startTokens := msgTokens(msgs[0])
|
|
budget := total - startTokens
|
|
|
|
result, changed := compressMessages(msgs, budget)
|
|
assert.True(t, changed)
|
|
|
|
// tool_use and tool_result should both be present or both absent.
|
|
hasToolUse := false
|
|
hasToolResult := false
|
|
for _, m := range result {
|
|
if isAssistantWithToolUse(m) {
|
|
hasToolUse = true
|
|
}
|
|
if isUserWithToolResult(m) {
|
|
hasToolResult = true
|
|
}
|
|
}
|
|
assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")
|
|
}
|
|
|
|
// TestCompressMessages_RemovesToolPairAtomically checks that when the budget
// only fits the final message, an assistant tool_use / user tool_result pair
// is dropped as a unit rather than leaving either half behind.
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
	msgs := []map[string]any{
		makeMsg("user", "start"),
		makeToolUseMsg("tool-1"),
		makeToolResultMsg("tool-1"),
		makeMsg("assistant", "final answer after tool use"),
	}

	// Budget: only keep the last "assistant" message.
	lastTokens := msgTokens(msgs[len(msgs)-1])

	result, changed := compressMessages(msgs, lastTokens)
	assert.True(t, changed)

	// Guard against a vacuous pass: an empty result would trivially satisfy
	// the "no tool blocks" loop below. The newest message must survive.
	require.NotEmpty(t, result)
	assert.Equal(t, msgs[len(msgs)-1]["content"], result[len(result)-1]["content"],
		"the newest message should be kept")

	// Neither tool_use nor tool_result should remain.
	for _, m := range result {
		assert.False(t, isAssistantWithToolUse(m), "tool_use should have been removed with its pair")
		assert.False(t, isUserWithToolResult(m), "tool_result should have been removed with its pair")
	}
}
|
|
|
|
// TestCompressMessagesInBody_NoMessages checks that a body with no "messages"
// field comes back byte-for-byte identical, even under a tiny budget.
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
	input := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
	output := compressMessagesInBody(input, 1)
	assert.Equal(t, input, output, "body without messages should be unchanged")
}
|
|
|
|
// TestCompressMessagesInBody_UnderBudget checks that a body whose messages fit
// comfortably within the budget is returned unchanged.
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
	original := bodyWithMessages(t, []map[string]any{makeMsg("user", "hi")})
	got := compressMessagesInBody(original, 100_000)
	assert.Equal(t, original, got, "body under budget should be unchanged")
}
|
|
|
|
// TestCompressMessagesInBody_TrimsToBudget checks that an over-budget body is
// shrunk while remaining valid JSON with a non-empty "messages" array.
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
	const count = 20
	msgs := make([]map[string]any, 0, count)
	for i := 0; i < count; i++ {
		role := "assistant"
		if i%2 == 0 {
			role = "user"
		}
		text := fmt.Sprintf("message %d with some padding text to have enough tokens", i)
		msgs = append(msgs, makeMsg(role, text))
	}
	body := bodyWithMessages(t, msgs)

	// Force significant compression.
	result := compressMessagesInBody(body, 50)
	assert.Less(t, len(result), len(body), "compressed body should be smaller")

	// Resulting body should still be valid JSON with a messages array.
	var decoded map[string]any
	require.NoError(t, json.Unmarshal(result, &decoded))
	kept, isArray := decoded["messages"].([]any)
	require.True(t, isArray)
	assert.Greater(t, len(kept), 0, "messages array should not be empty")
}
|