215 lines
6.5 KiB
Go
215 lines
6.5 KiB
Go
package windsurf
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
func TestBuildToolPreambleForProtoCanonicalizesToolsAndChoice(t *testing.T) {
|
|
tools := []OpenAITool{
|
|
{
|
|
Type: "function",
|
|
Function: OpenAIFunction{
|
|
Name: "list_files",
|
|
Description: "List files in the repository",
|
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
|
},
|
|
},
|
|
{
|
|
Type: "function",
|
|
Function: OpenAIFunction{
|
|
Name: "glob",
|
|
Description: "Duplicate alias should be deduped",
|
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
|
},
|
|
},
|
|
}
|
|
|
|
got := BuildToolPreambleForProto(tools, map[string]any{
|
|
"type": "tool",
|
|
"name": "search_files",
|
|
})
|
|
|
|
if strings.Contains(got, "### list_files") {
|
|
t.Fatalf("preamble should not expose alias tool names: %s", got)
|
|
}
|
|
if count := strings.Count(got, `"name":"glob"`); count != 1 {
|
|
t.Fatalf("expected exactly one canonical glob tool, got %d in %s", count, got)
|
|
}
|
|
if !strings.Contains(got, `You must call the function "grep"`) {
|
|
t.Fatalf("forced tool choice should be canonicalized to grep: %s", got)
|
|
}
|
|
if !strings.Contains(got, "relative file paths and search paths") {
|
|
t.Fatalf("preamble should explain workspace-relative path semantics: %s", got)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeMessagesForCascadePreservesStructuredToolResultPayload(t *testing.T) {
|
|
messages := []AnthropicMessage{
|
|
{
|
|
Role: "tool",
|
|
ToolCallID: "call-1",
|
|
Content: json.RawMessage(`[
|
|
{"type":"text","text":"partial listing"},
|
|
{"type":"json","value":{"entries":["a.go","b.go"]}}
|
|
]`),
|
|
},
|
|
}
|
|
|
|
got := NormalizeMessagesForCascade(messages, nil)
|
|
if len(got) != 1 {
|
|
t.Fatalf("NormalizeMessagesForCascade() returned %d messages, want 1", len(got))
|
|
}
|
|
if !strings.Contains(got[0].Content, `"type":"json"`) {
|
|
t.Fatalf("structured tool_result payload should be preserved, got %q", got[0].Content)
|
|
}
|
|
if !strings.Contains(got[0].Content, "Continue the prior user request") {
|
|
t.Fatalf("tool_result should instruct model to continue prior request, got %q", got[0].Content)
|
|
}
|
|
if !strings.Contains(got[0].Content, `tool_call_id="call-1"`) {
|
|
t.Fatalf("tool_result should preserve tool call id, got %q", got[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeMessagesForCascadePromotesSlashCommandArgs(t *testing.T) {
|
|
messages := []AnthropicMessage{{
|
|
Role: "user",
|
|
Content: json.RawMessage(`[
|
|
{"type":"text","text":"<command-name>/ccg:plan</command-name>\n<command-message>Long slash command spec that says ask for feature name.</command-message>\n<command-args>分析一下这个项目 我感觉 计费逻辑出问题了</command-args>\n"}
|
|
]`),
|
|
}}
|
|
|
|
got := NormalizeMessagesForCascade(messages, nil)
|
|
if len(got) != 1 {
|
|
t.Fatalf("NormalizeMessagesForCascade() len = %d, want 1", len(got))
|
|
}
|
|
if !strings.Contains(got[0].Content, "Actual user request from the slash command arguments") {
|
|
t.Fatalf("slash command args should be promoted, got %q", got[0].Content)
|
|
}
|
|
if !strings.Contains(got[0].Content, "计费逻辑出问题了") {
|
|
t.Fatalf("actual user request should be preserved, got %q", got[0].Content)
|
|
}
|
|
if strings.Contains(got[0].Content, "Long slash command spec") {
|
|
t.Fatalf("command-message spec should be stripped, got %q", got[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestBuildToolPreambleForProtoWithEnvironmentPrefixesFacts(t *testing.T) {
|
|
tools := []OpenAITool{{
|
|
Type: "function",
|
|
Function: OpenAIFunction{
|
|
Name: "read",
|
|
Parameters: json.RawMessage(`{"type":"object"}`),
|
|
},
|
|
}}
|
|
|
|
got := BuildToolPreambleForProtoWithEnvironment(tools, nil, "<environment_context>\nWorking directory: /Users/user/project\n</environment_context>")
|
|
|
|
if !strings.HasPrefix(got, "## Environment facts") {
|
|
preview := got
|
|
if len(preview) > 80 {
|
|
preview = preview[:80]
|
|
}
|
|
t.Fatalf("environment facts should prefix proto preamble, got %q", preview)
|
|
}
|
|
if !strings.Contains(got, "Prefer these environment facts over any default Cascade workspace assumption") {
|
|
t.Fatalf("environment block should override Cascade workspace prior: %s", got)
|
|
}
|
|
}
|
|
|
|
func TestParseToolCallsFromTextNormalizesAliases(t *testing.T) {
|
|
text := strings.Join([]string{
|
|
`<tool_call>{"name":"list_files","arguments":{"path":"."}}</tool_call>`,
|
|
`{"name":"search_files","arguments":{"pattern":"TODO"}}`,
|
|
`{"tool_code":"apply_patch(\"*** Begin Patch\")"}`,
|
|
}, "\n")
|
|
|
|
got := ParseToolCallsFromText(text)
|
|
if len(got.ToolCalls) != 3 {
|
|
t.Fatalf("ParseToolCallsFromText() returned %d tool calls, want 3", len(got.ToolCalls))
|
|
}
|
|
|
|
wantNames := []string{"glob", "grep", "edit"}
|
|
for i, want := range wantNames {
|
|
if got.ToolCalls[i].Name != want {
|
|
t.Fatalf("tool call %d name = %q, want %q", i, got.ToolCalls[i].Name, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSanitizePathMarksUnmountedWorkspace(t *testing.T) {
|
|
got := SanitizePath("/tmp/windsurf-workspace/pkg/main.go")
|
|
if got != "[unmounted-workspace]/pkg/main.go" {
|
|
t.Fatalf("SanitizePath() = %q, want %q", got, "[unmounted-workspace]/pkg/main.go")
|
|
}
|
|
}
|
|
|
|
func TestWarmupCascadeSkipsTrackedWorkspaceByDefault(t *testing.T) {
|
|
var mu sync.Mutex
|
|
var paths []string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
paths = append(paths, r.URL.Path)
|
|
mu.Unlock()
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := NewLocalLSClient(42099, "csrf")
|
|
client.BaseURL = server.URL
|
|
client.HTTP = server.Client()
|
|
|
|
if err := client.WarmupCascade(context.Background(), "token"); err != nil {
|
|
t.Fatalf("WarmupCascade() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
for _, path := range paths {
|
|
if path == AddTrackedWorkspaceRPC {
|
|
t.Fatalf("WarmupCascade() unexpectedly called AddTrackedWorkspaceRPC: %v", paths)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestWarmupCascadeAddsConfiguredWorkspace(t *testing.T) {
|
|
var mu sync.Mutex
|
|
var paths []string
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
paths = append(paths, r.URL.Path)
|
|
mu.Unlock()
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer server.Close()
|
|
|
|
client := NewLocalLSClient(42099, "csrf")
|
|
client.BaseURL = server.URL
|
|
client.HTTP = server.Client()
|
|
client.TrackedWorkspace = "/repo"
|
|
|
|
if err := client.WarmupCascade(context.Background(), "token"); err != nil {
|
|
t.Fatalf("WarmupCascade() error = %v", err)
|
|
}
|
|
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
found := false
|
|
for _, path := range paths {
|
|
if path == AddTrackedWorkspaceRPC {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatalf("WarmupCascade() should call AddTrackedWorkspaceRPC when TrackedWorkspace is configured: %v", paths)
|
|
}
|
|
}
|