sub2api/backend/internal/service/context_compressor_test.go
win d535688bfd feat(context): add proactive context compression for long conversations
- New context_compressor.go: pure functions operating on raw JSON body
  (gjson/sjson pattern). approxTokens uses chars/4 heuristic.
- compressMessages: removes oldest messages from front, treating
  consecutive assistant(tool_use)+user(tool_result) pairs as atomic units
  to prevent orphaned tool_result blocks.
- Hooked into Forward() after StripEmptyTextBlocks, gated on
  account.Credentials[enable_context_compression].
- Config: gateway.context_compression.max_tokens (default 190000).
- 8 unit tests covering: approx tokens, no-op when under budget,
  oldest-message trimming, tool pair preservation, atomic pair removal,
  body passthrough, body trimming.
2026-04-29 01:33:05 +08:00

196 lines
5.4 KiB
Go

package service
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// helpers
// makeMsg builds a chat message with the given role whose "content" is the
// plain string text.
func makeMsg(role, text string) map[string]any {
	msg := make(map[string]any, 2)
	msg["role"] = role
	msg["content"] = text
	return msg
}
// makeToolUseMsg builds an assistant message whose content is a single
// tool_use block with the given id, the tool name "search", and an empty
// input object.
func makeToolUseMsg(id string) map[string]any {
	toolUse := map[string]any{
		"type":  "tool_use",
		"id":    id,
		"name":  "search",
		"input": map[string]any{},
	}
	return map[string]any{
		"role":    "assistant",
		"content": []any{toolUse},
	}
}
// makeToolResultMsg builds a user message whose content is a single
// tool_result block referencing toolUseID, with the fixed string content
// "result text".
func makeToolResultMsg(toolUseID string) map[string]any {
	toolResult := map[string]any{
		"type":        "tool_result",
		"tool_use_id": toolUseID,
		"content":     "result text",
	}
	return map[string]any{
		"role":    "user",
		"content": []any{toolResult},
	}
}
// toAnySlice copies msgs element-by-element into a []any, preserving order.
// The returned slice is always non-nil, even when msgs is empty.
func toAnySlice(msgs []map[string]any) []any {
	converted := make([]any, 0, len(msgs))
	for _, msg := range msgs {
		converted = append(converted, msg)
	}
	return converted
}
// bodyWithMessages JSON-encodes a minimal request body holding msgs under
// "messages" plus a fixed "model" value, failing the test if encoding fails.
func bodyWithMessages(t *testing.T, msgs []map[string]any) []byte {
	t.Helper()
	payload := map[string]any{
		"model":    "claude-3-5-sonnet-20241022",
		"messages": msgs,
	}
	raw, err := json.Marshal(payload)
	require.NoError(t, err)
	return raw
}
// tests
// TestApproxTokens pins a few values of the chars/4 token estimate.
func TestApproxTokens(t *testing.T) {
	cases := []struct {
		in   string
		want int
	}{
		{"four", 1},         // 4 chars → 1 token
		{"0123456789ab", 3}, // 12 chars → 3 tokens
		{"", 0},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, approxTokens(tc.in))
	}
}
// TestCompressMessages_NoCompressionNeeded checks that a short conversation
// well under budget is reported as unchanged and keeps every message.
func TestCompressMessages_NoCompressionNeeded(t *testing.T) {
	const generousBudget = 100_000
	input := []map[string]any{makeMsg("user", "hi"), makeMsg("assistant", "hello")}

	out, changed := compressMessages(input, generousBudget)

	assert.False(t, changed)
	assert.Len(t, out, 2)
}
// TestCompressMessages_TrimsOldestMessages checks that, under a budget too
// small for the whole conversation, compressMessages reports a change, drops
// messages, and keeps the newest message at the tail.
func TestCompressMessages_TrimsOldestMessages(t *testing.T) {
	// 10 alternating user/assistant messages, each long enough that the
	// combined estimate is far above the budget used below.
	msgs := make([]map[string]any, 10)
	for i := range msgs {
		role := "user"
		if i%2 == 1 {
			role = "assistant"
		}
		msgs[i] = makeMsg(role, fmt.Sprintf("message number %d with some content to increase token count", i))
	}

	// Force compression by using a very small token budget.
	result, changed := compressMessages(msgs, 1)
	assert.True(t, changed)

	// require (not assert): the index expression below would panic on an
	// empty result, hiding the real failure.
	require.NotEmpty(t, result, "must keep at least the newest message")
	assert.Less(t, len(result), len(msgs), "some messages should have been removed")

	// The remaining messages should be from the tail (newest).
	assert.Equal(t, msgs[len(msgs)-1]["content"], result[len(result)-1]["content"])
}
// TestCompressMessages_PreservesToolUsePairs checks that trimming never leaves
// a tool_result without the tool_use message that immediately precedes it.
func TestCompressMessages_PreservesToolUsePairs(t *testing.T) {
	// Messages: user → assistant+tool_use → user+tool_result → assistant
	msgs := []map[string]any{
		makeMsg("user", "start"),
		makeToolUseMsg("tool-1"),
		makeToolResultMsg("tool-1"),
		makeMsg("assistant", "done"),
	}

	// Budget = total estimate minus the "start" message. The intent is to drop
	// only "start"; the assertions below check pairing invariants rather than
	// the exact cut point, since that depends on compressMessages' accounting.
	total := 0
	for _, m := range msgs {
		total += msgTokens(m)
	}
	budget := total - msgTokens(msgs[0])

	result, changed := compressMessages(msgs, budget)
	assert.True(t, changed)

	// tool_use and tool_result should both be present or both absent.
	hasToolUse := false
	hasToolResult := false
	for _, m := range result {
		if isAssistantWithToolUse(m) {
			hasToolUse = true
		}
		if isUserWithToolResult(m) {
			hasToolResult = true
		}
	}
	assert.Equal(t, hasToolUse, hasToolResult, "tool_use and tool_result must appear together or not at all")

	// Presence alone is not enough: each surviving tool_result must directly
	// follow a tool_use, otherwise it was orphaned by trimming from the front.
	for i, m := range result {
		if !isUserWithToolResult(m) {
			continue
		}
		require.Greater(t, i, 0, "tool_result must not be the first remaining message")
		assert.True(t, isAssistantWithToolUse(result[i-1]), "tool_result at index %d must follow a tool_use", i)
	}
}
// TestCompressMessages_RemovesToolPairAtomically checks that when the budget
// only fits the final message, the tool_use/tool_result pair is dropped as a
// unit and neither half survives.
func TestCompressMessages_RemovesToolPairAtomically(t *testing.T) {
	msgs := []map[string]any{
		makeMsg("user", "start"),
		makeToolUseMsg("tool-1"),
		makeToolResultMsg("tool-1"),
		makeMsg("assistant", "final answer after tool use"),
	}

	// Budget sized to the last message alone, forcing the pair out.
	budget := msgTokens(msgs[3])
	result, changed := compressMessages(msgs, budget)
	assert.True(t, changed)

	for _, remaining := range result {
		assert.False(t, isAssistantWithToolUse(remaining), "tool_use should have been removed with its pair")
		assert.False(t, isUserWithToolResult(remaining), "tool_result should have been removed with its pair")
	}
}
// TestCompressMessagesInBody_NoMessages checks that a body lacking a
// "messages" field is returned unchanged, even with a tiny budget.
func TestCompressMessagesInBody_NoMessages(t *testing.T) {
	input := []byte(`{"model":"claude-3-5-sonnet-20241022"}`)
	got := compressMessagesInBody(input, 1)
	assert.Equal(t, input, got, "body without messages should be unchanged")
}
func TestCompressMessagesInBody_UnderBudget(t *testing.T) {
msgs := []map[string]any{makeMsg("user", "hi")}
body := bodyWithMessages(t, msgs)
result := compressMessagesInBody(body, 100_000)
assert.Equal(t, body, result, "body under budget should be unchanged")
}
// TestCompressMessagesInBody_TrimsToBudget checks that an over-budget body is
// rewritten into valid JSON with fewer messages, that the newest message is
// kept, and that other top-level fields survive.
func TestCompressMessagesInBody_TrimsToBudget(t *testing.T) {
	msgs := make([]map[string]any, 20)
	for i := range msgs {
		role := "user"
		if i%2 == 1 {
			role = "assistant"
		}
		msgs[i] = makeMsg(role, fmt.Sprintf("message %d with some padding text to have enough tokens", i))
	}
	body := bodyWithMessages(t, msgs)

	// Force significant compression.
	result := compressMessagesInBody(body, 50)
	assert.Less(t, len(result), len(body), "compressed body should be smaller")

	// Resulting body should still be valid JSON with a messages array.
	var parsed map[string]any
	require.NoError(t, json.Unmarshal(result, &parsed))
	resultMsgs, ok := parsed["messages"].([]any)
	require.True(t, ok, "messages should decode as a JSON array")
	// require: resultMsgs is indexed below.
	require.NotEmpty(t, resultMsgs, "messages array should not be empty")
	assert.Less(t, len(resultMsgs), len(msgs), "some messages should have been dropped")

	// Trimming removes from the front, so the newest message must survive.
	last, ok := resultMsgs[len(resultMsgs)-1].(map[string]any)
	require.True(t, ok, "each message should decode as a JSON object")
	assert.Equal(t, msgs[len(msgs)-1]["content"], last["content"])

	// Only the messages array should be rewritten; other fields pass through.
	assert.Equal(t, "claude-3-5-sonnet-20241022", parsed["model"])
}