204 lines
6.2 KiB
Go
204 lines
6.2 KiB
Go
package windsurf
|
||
|
||
import (
|
||
"context"
|
||
"encoding/binary"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"sync"
|
||
"testing"
|
||
)
|
||
|
||
// TestSessionIDIsolatedPerAccount locks in the fix that prevents Windsurf
|
||
// 上下文记忆乱套:单例 LocalLSClient 必须为每个上游 account 维护独立的
|
||
// SessionID,否则不同用户的请求会被本地 LS 视作同一 cascade 会话,导致历史串扰。
|
||
func TestSessionIDIsolatedPerAccount(t *testing.T) {
|
||
c := NewLocalLSClient(0, "csrf")
|
||
|
||
idA1 := c.sessionIDForAccount(1001, "user-a-token-v1")
|
||
idB1 := c.sessionIDForAccount(1002, "user-b-token")
|
||
|
||
if idA1 == "" || idB1 == "" {
|
||
t.Fatalf("expected non-empty SessionIDs, got %q / %q", idA1, idB1)
|
||
}
|
||
if idA1 == idB1 {
|
||
t.Fatalf("SessionIDs leaked across tokens: %q", idA1)
|
||
}
|
||
|
||
// Stable across token refreshes for the same account.
|
||
if idA2 := c.sessionIDForAccount(1001, "user-a-token-v2"); idA2 != idA1 {
|
||
t.Fatalf("SessionID for account A changed across token refresh: %q -> %q", idA1, idA2)
|
||
}
|
||
if idB2 := c.sessionIDForAccount(1002, "user-b-token"); idB2 != idB1 {
|
||
t.Fatalf("SessionID for account B changed across calls: %q -> %q", idB1, idB2)
|
||
}
|
||
}
|
||
|
||
func TestSessionFallbackUsesTokenHash(t *testing.T) {
|
||
c := NewLocalLSClient(0, "csrf")
|
||
|
||
rawToken := "user-a-very-secret-token"
|
||
_ = c.sessionIDFor(rawToken)
|
||
|
||
if _, ok := c.sessionSlots[rawToken]; ok {
|
||
t.Fatalf("sessionSlots must not store raw token keys")
|
||
}
|
||
if _, ok := c.sessionSlots["token:"+apiKeyHash(rawToken)]; !ok {
|
||
t.Fatalf("sessionSlots should store fallback token hash key")
|
||
}
|
||
}
|
||
|
||
// TestWarmedFlagIsolatedPerAccount verifies that marking one account as warmed
|
||
// does NOT cause a different account's warmup path to be skipped — the bug
|
||
// before the fix was that the global Warmed=true set by user A let user B's
|
||
// request bypass InitializeCascadePanelState entirely while reusing user A's
|
||
// SessionID.
|
||
func TestWarmedFlagIsolatedPerAccount(t *testing.T) {
|
||
c := NewLocalLSClient(0, "csrf")
|
||
|
||
c.markSessionWarmedForAccount(1001, "user-a-token", true)
|
||
|
||
if !c.isSessionWarmedForAccount(1001, "user-a-token-refreshed") {
|
||
t.Fatalf("account A should be warmed after markSessionWarmedForAccount")
|
||
}
|
||
if c.isSessionWarmedForAccount(1002, "user-b-token") {
|
||
t.Fatalf("account B must NOT be considered warmed; warm state leaked across accounts")
|
||
}
|
||
}
|
||
|
||
// TestResetSessionPerAccount verifies resetSession only rotates the target
|
||
// account's SessionID and does not disturb other accounts.
|
||
func TestResetSessionPerAccount(t *testing.T) {
|
||
c := NewLocalLSClient(0, "csrf")
|
||
|
||
idA1 := c.sessionIDForAccount(1001, "user-a-token")
|
||
idB1 := c.sessionIDForAccount(1002, "user-b-token")
|
||
|
||
c.markSessionWarmedForAccount(1001, "user-a-token", true)
|
||
c.markSessionWarmedForAccount(1002, "user-b-token", true)
|
||
|
||
c.resetSessionForAccount(1001, "user-a-token")
|
||
|
||
idA2 := c.sessionIDForAccount(1001, "user-a-token-refreshed")
|
||
idB2 := c.sessionIDForAccount(1002, "user-b-token")
|
||
|
||
if idA2 == idA1 {
|
||
t.Fatalf("resetSession(tokenA) did not rotate SessionID: %q", idA1)
|
||
}
|
||
if c.isSessionWarmedForAccount(1001, "user-a-token") {
|
||
t.Fatalf("resetSession(tokenA) did not clear warmed flag")
|
||
}
|
||
if idB2 != idB1 {
|
||
t.Fatalf("resetSession(tokenA) clobbered tokenB SessionID: %q -> %q", idB1, idB2)
|
||
}
|
||
if !c.isSessionWarmedForAccount(1002, "user-b-token") {
|
||
t.Fatalf("resetSession(tokenA) clobbered tokenB warmed flag")
|
||
}
|
||
}
|
||
|
||
func TestStartCascadeMetadataUsesAccountScopedSessionID(t *testing.T) {
|
||
var mu sync.Mutex
|
||
sessionIDs := map[string]string{}
|
||
|
||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
body, _ := io.ReadAll(r.Body)
|
||
payload := stripGRPCFrame(body)
|
||
meta := testBytesField(t, payload, 1)
|
||
token := testStringField(t, meta, 3)
|
||
sessionID := testStringField(t, meta, 10)
|
||
|
||
mu.Lock()
|
||
sessionIDs[token] = sessionID
|
||
mu.Unlock()
|
||
|
||
w.Header().Set("Content-Type", "application/grpc")
|
||
w.Header().Set("grpc-status", "0")
|
||
// StartCascade response: field 1 cascade_id.
|
||
resp := encodeStringField(1, "cascade-"+token)
|
||
frame := make([]byte, 5+len(resp))
|
||
binary.BigEndian.PutUint32(frame[1:5], uint32(len(resp)))
|
||
copy(frame[5:], resp)
|
||
_, _ = w.Write(frame)
|
||
}))
|
||
defer server.Close()
|
||
|
||
c := NewLocalLSClient(0, "csrf")
|
||
c.BaseURL = server.URL
|
||
c.HTTP = server.Client()
|
||
|
||
if _, err := c.StartCascadeForAccount(context.Background(), 1001, "token-a-v1"); err != nil {
|
||
t.Fatalf("StartCascade(token-a) error = %v", err)
|
||
}
|
||
if _, err := c.StartCascadeForAccount(context.Background(), 1001, "token-a-v2"); err != nil {
|
||
t.Fatalf("StartCascade(token-a refreshed) error = %v", err)
|
||
}
|
||
if _, err := c.StartCascadeForAccount(context.Background(), 1002, "token-b"); err != nil {
|
||
t.Fatalf("StartCascade(token-b) error = %v", err)
|
||
}
|
||
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
idA1 := sessionIDs["token-a-v1"]
|
||
idA2 := sessionIDs["token-a-v2"]
|
||
idB := sessionIDs["token-b"]
|
||
if idA1 == "" || idA2 == "" || idB == "" {
|
||
t.Fatalf("expected captured session IDs, got %q / %q / %q", idA1, idA2, idB)
|
||
}
|
||
if idA1 != idA2 {
|
||
t.Fatalf("StartCascade metadata changed session ID across token refresh: %q -> %q", idA1, idA2)
|
||
}
|
||
if idA1 == idB {
|
||
t.Fatalf("StartCascade metadata reused session ID across accounts: %q", idA1)
|
||
}
|
||
}
|
||
|
||
func testBytesField(t *testing.T, data []byte, wantField uint64) []byte {
|
||
t.Helper()
|
||
pos := 0
|
||
for pos < len(data) {
|
||
tag, next, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
t.Fatalf("failed to read tag at pos %d", pos)
|
||
}
|
||
pos = next
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
switch wireType {
|
||
case 2:
|
||
length, next, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
t.Fatalf("failed to read length at pos %d", pos)
|
||
}
|
||
pos = next
|
||
end := pos + int(length)
|
||
if end > len(data) {
|
||
t.Fatalf("field %d out of bounds: end=%d len=%d", fieldNum, end, len(data))
|
||
}
|
||
if fieldNum == wantField {
|
||
return data[pos:end]
|
||
}
|
||
pos = end
|
||
case 0:
|
||
_, next, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
t.Fatalf("failed to skip varint at pos %d", pos)
|
||
}
|
||
pos = next
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
t.Fatalf("unexpected wire type %d at pos %d", wireType, pos)
|
||
}
|
||
}
|
||
t.Fatalf("field %d not found", wantField)
|
||
return nil
|
||
}
|
||
|
||
func testStringField(t *testing.T, data []byte, wantField uint64) string {
|
||
t.Helper()
|
||
return string(testBytesField(t, data, wantField))
|
||
}
|