sub2api/backend/internal/pkg/windsurf/local_ls_session_isolation_test.go

204 lines
6.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package windsurf
import (
	"context"
	"encoding/binary"
	"io"
	"net/http"
	"net/http/httptest"
	"strconv"
	"sync"
	"testing"
)
// TestSessionIDIsolatedPerAccount locks in the fix that keeps Windsurf
// conversation context from getting mixed up between users: the singleton
// LocalLSClient must keep a separate SessionID for each upstream account.
// Otherwise the local LS treats requests from different users as one cascade
// session and their histories bleed into each other.
//
// It also pins that an account's SessionID is stable, both across a token
// refresh and across repeated lookups with the same token.
func TestSessionIDIsolatedPerAccount(t *testing.T) {
	client := NewLocalLSClient(0, "csrf")

	firstA := client.sessionIDForAccount(1001, "user-a-token-v1")
	firstB := client.sessionIDForAccount(1002, "user-b-token")
	switch {
	case firstA == "" || firstB == "":
		t.Fatalf("expected non-empty SessionIDs, got %q / %q", firstA, firstB)
	case firstA == firstB:
		t.Fatalf("SessionIDs leaked across tokens: %q", firstA)
	}

	// Same account, new token: the SessionID must not change.
	secondA := client.sessionIDForAccount(1001, "user-a-token-v2")
	if secondA != firstA {
		t.Fatalf("SessionID for account A changed across token refresh: %q -> %q", firstA, secondA)
	}

	// Same account, same token: repeated lookups must agree.
	secondB := client.sessionIDForAccount(1002, "user-b-token")
	if secondB != firstB {
		t.Fatalf("SessionID for account B changed across calls: %q -> %q", firstB, secondB)
	}
}
// TestSessionFallbackUsesTokenHash checks the fallback path, where a session
// is looked up by token alone via sessionIDFor. The resulting sessionSlots
// entry must be keyed by "token:" followed by apiKeyHash of the token. The
// raw token itself must never appear as a map key.
func TestSessionFallbackUsesTokenHash(t *testing.T) {
	const secret = "user-a-very-secret-token"
	client := NewLocalLSClient(0, "csrf")

	// Only the side effect on sessionSlots matters here.
	_ = client.sessionIDFor(secret)

	if _, leaked := client.sessionSlots[secret]; leaked {
		t.Fatalf("sessionSlots must not store raw token keys")
	}
	hashedKey := "token:" + apiKeyHash(secret)
	if _, found := client.sessionSlots[hashedKey]; !found {
		t.Fatalf("sessionSlots should store fallback token hash key")
	}
}
// TestWarmedFlagIsolatedPerAccount verifies that warming one account does not
// let a different account skip its own warmup.
//
// Before the fix, a single global Warmed=true set by user A let user B's
// request bypass InitializeCascadePanelState entirely while reusing user A's
// SessionID. The test also pins that the warm state follows the account, so
// a refreshed token for the same account still sees it.
func TestWarmedFlagIsolatedPerAccount(t *testing.T) {
	client := NewLocalLSClient(0, "csrf")
	client.markSessionWarmedForAccount(1001, "user-a-token", true)

	accountAWarm := client.isSessionWarmedForAccount(1001, "user-a-token-refreshed")
	if !accountAWarm {
		t.Fatalf("account A should be warmed after markSessionWarmedForAccount")
	}

	accountBWarm := client.isSessionWarmedForAccount(1002, "user-b-token")
	if accountBWarm {
		t.Fatalf("account B must NOT be considered warmed; warm state leaked across accounts")
	}
}
// TestResetSessionPerAccount verifies that resetSessionForAccount affects only
// the target account. For that account it must give a fresh, non-empty
// SessionID and clear the warmed flag. Every other account's SessionID and
// warmed flag must be left exactly as they were.
func TestResetSessionPerAccount(t *testing.T) {
	c := NewLocalLSClient(0, "csrf")
	idA1 := c.sessionIDForAccount(1001, "user-a-token")
	idB1 := c.sessionIDForAccount(1002, "user-b-token")
	c.markSessionWarmedForAccount(1001, "user-a-token", true)
	c.markSessionWarmedForAccount(1002, "user-b-token", true)

	c.resetSessionForAccount(1001, "user-a-token")

	idA2 := c.sessionIDForAccount(1001, "user-a-token-refreshed")
	idB2 := c.sessionIDForAccount(1002, "user-b-token")

	// Without this check, a reset that produced an empty SessionID would still
	// satisfy the rotation check below.
	if idA2 == "" {
		t.Fatalf("resetSessionForAccount(account A) left an empty SessionID")
	}
	if idA2 == idA1 {
		t.Fatalf("resetSessionForAccount(account A) did not rotate SessionID: %q", idA1)
	}
	if c.isSessionWarmedForAccount(1001, "user-a-token") {
		t.Fatalf("resetSessionForAccount(account A) did not clear warmed flag")
	}
	if idB2 != idB1 {
		t.Fatalf("resetSessionForAccount(account A) clobbered account B SessionID: %q -> %q", idB1, idB2)
	}
	if !c.isSessionWarmedForAccount(1002, "user-b-token") {
		t.Fatalf("resetSessionForAccount(account A) clobbered account B warmed flag")
	}
}
// TestStartCascadeMetadataUsesAccountScopedSessionID checks, at the wire level,
// which SessionID StartCascadeForAccount sends in the request metadata
// (payload field 1: token in field 3, SessionID in field 10). The same account
// must keep its SessionID across a token refresh. Different accounts must
// never share one.
//
// The fake LS handler only records request payloads. Decoding happens later
// on the test goroutine, because testBytesField/testStringField call t.Fatalf,
// and the testing package forbids FailNow from the server's handler goroutine.
func TestStartCascadeMetadataUsesAccountScopedSessionID(t *testing.T) {
	var (
		mu       sync.Mutex
		payloads [][]byte
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// t.Errorf, unlike t.Fatalf, may be called from this goroutine.
			// server.Close (deferred below) waits for in-flight handlers, so
			// this cannot run after the test has returned.
			t.Errorf("reading StartCascade request body: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		mu.Lock()
		payloads = append(payloads, stripGRPCFrame(body))
		n := len(payloads)
		mu.Unlock()

		w.Header().Set("Content-Type", "application/grpc")
		w.Header().Set("grpc-status", "0")
		// StartCascade response: field 1 cascade_id, unique per request.
		resp := encodeStringField(1, "cascade-"+strconv.Itoa(n))
		// 5-byte gRPC message prefix: frame[0] (compression flag) stays 0,
		// frame[1:5] holds the big-endian message length.
		frame := make([]byte, 5+len(resp))
		binary.BigEndian.PutUint32(frame[1:5], uint32(len(resp)))
		copy(frame[5:], resp)
		_, _ = w.Write(frame) // a write failure surfaces as a client-side error
	}))
	defer server.Close()

	c := NewLocalLSClient(0, "csrf")
	c.BaseURL = server.URL
	c.HTTP = server.Client()
	if _, err := c.StartCascadeForAccount(context.Background(), 1001, "token-a-v1"); err != nil {
		t.Fatalf("StartCascade(token-a) error = %v", err)
	}
	if _, err := c.StartCascadeForAccount(context.Background(), 1001, "token-a-v2"); err != nil {
		t.Fatalf("StartCascade(token-a refreshed) error = %v", err)
	}
	if _, err := c.StartCascadeForAccount(context.Background(), 1002, "token-b"); err != nil {
		t.Fatalf("StartCascade(token-b) error = %v", err)
	}

	mu.Lock()
	captured := payloads
	mu.Unlock()

	// Decode on the test goroutine so the helpers' t.Fatalf calls are legal.
	sessionIDs := make(map[string]string, len(captured))
	for _, payload := range captured {
		meta := testBytesField(t, payload, 1)
		token := testStringField(t, meta, 3)
		sessionIDs[token] = testStringField(t, meta, 10)
	}

	idA1 := sessionIDs["token-a-v1"]
	idA2 := sessionIDs["token-a-v2"]
	idB := sessionIDs["token-b"]
	if idA1 == "" || idA2 == "" || idB == "" {
		t.Fatalf("expected captured session IDs, got %q / %q / %q", idA1, idA2, idB)
	}
	if idA1 != idA2 {
		t.Fatalf("StartCascade metadata changed session ID across token refresh: %q -> %q", idA1, idA2)
	}
	if idA1 == idB {
		t.Fatalf("StartCascade metadata reused session ID across accounts: %q", idA1)
	}
}
// testBytesField scans data as a sequence of protobuf wire-format fields. It
// returns the payload of the first length-delimited (wire type 2) field whose
// number is wantField.
//
// Varint (0), fixed64 (1) and fixed32 (5) fields are skipped. The test fails
// via t.Fatalf on any other wire type, on a truncated or out-of-bounds field,
// or when the field is absent.
//
// The returned slice aliases data. Because it calls t.Fatalf, this helper must
// only be used from the goroutine running the test.
func testBytesField(t *testing.T, data []byte, wantField uint64) []byte {
	t.Helper()
	pos := 0
	for pos < len(data) {
		tag, next, ok := ReadVarint(data, pos)
		if !ok {
			t.Fatalf("failed to read tag at pos %d", pos)
		}
		pos = next
		fieldNum := tag >> 3
		wireType := tag & 7
		switch wireType {
		case 2: // length-delimited
			length, next, ok := ReadVarint(data, pos)
			if !ok {
				t.Fatalf("failed to read length at pos %d", pos)
			}
			pos = next
			// Bounds-check in uint64 space. Converting a huge, malformed length
			// to int first could wrap negative, slip past the check and panic
			// on the slice expression instead of failing the test cleanly.
			remaining := len(data) - pos
			if remaining < 0 || length > uint64(remaining) {
				t.Fatalf("field %d out of bounds: length=%d remaining=%d", fieldNum, length, remaining)
			}
			end := pos + int(length)
			if fieldNum == wantField {
				return data[pos:end]
			}
			pos = end
		case 0: // varint
			_, next, ok := ReadVarint(data, pos)
			if !ok {
				t.Fatalf("failed to skip varint at pos %d", pos)
			}
			pos = next
		case 1: // fixed64
			pos += 8
		case 5: // fixed32
			pos += 4
		default:
			t.Fatalf("unexpected wire type %d at pos %d", wireType, pos)
		}
		// Fixed-width skips can overshoot. Report truncation explicitly
		// instead of falling out of the loop with a misleading "not found".
		if pos > len(data) {
			t.Fatalf("field %d (wire type %d) truncated: ends at %d, len=%d", fieldNum, wireType, pos, len(data))
		}
	}
	t.Fatalf("field %d not found", wantField)
	return nil
}
// testStringField returns the length-delimited field wantField of the
// protobuf-encoded data as a Go string. Lookup and failure behavior are those
// of testBytesField.
func testStringField(t *testing.T, data []byte, wantField uint64) string {
	t.Helper()
	raw := testBytesField(t, data, wantField)
	return string(raw)
}