sub2api/backend/internal/pkg/windsurf/conversation_pool.go
win 002066e700 chore(wip): 保存订制改动以便合并上游
- windsurf: client/pool/local_ls/tool_emulation/tool_names/models 调整
- handler: admin account_data / failover_loop / gateway_handler
- repository: scheduler_cache 及测试
- service: windsurf_chat_service / windsurf_gateway_service
- deploy: compose 合并为单文件(含 windsurf-ls profile),Dockerfile.ls
- cmd: 新增 dump_ls_models / dump_preamble / test_windsurf_tools 辅助工具
2026-04-24 11:14:36 +08:00

188 lines
4.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package windsurf
import (
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"sort"
	"sync"
	"time"
)
// poolTTL is how long an entry may sit idle (since its LastAccess) before it
// is treated as expired and dropped.
const poolTTL = 30 * time.Minute

// poolMax caps the number of pooled conversations; once exceeded, the
// least-recently-accessed entries are evicted.
const poolMax = 500
// ConversationEntry records an upstream conversation (cascade) that can be
// resumed by a later request, together with the account and LS it is bound to.
type ConversationEntry struct {
// CascadeID identifies the upstream cascade. A cascade is only valid on the
// account/LS that created it, so it must not be reused across accounts.
CascadeID string
// SessionID — presumably the upstream session the cascade belongs to; TODO confirm.
SessionID string
// LSPort is the port of the LS holding the cascade; InvalidateFor can drop entries by it.
LSPort int
// APIKey is the upstream account key the cascade is bound to; InvalidateFor can drop entries by it.
APIKey string
// CreatedAt is filled in by Checkin on the first store if it is still zero.
CreatedAt time.Time
// LastAccess is refreshed by every Checkin and drives TTL expiry and LRU eviction.
LastAccess time.Time
}
// ConversationPool maps a conversation fingerprint to a resumable upstream
// conversation. Entries expire after poolTTL of inactivity and the pool is
// capped at poolMax entries (LRU eviction).
type ConversationPool struct {
// mu guards pool and stats; every method that touches them holds it.
mu sync.Mutex
// pool maps a fingerprint (see FingerprintBefore/FingerprintAfter) to its entry.
pool map[string]*ConversationEntry
// stats holds cumulative hit/miss/store/eviction/expiry counters.
stats poolStats
}
// poolStats holds cumulative counters for a ConversationPool.
// NOTE(review): the JSON tags suggest it is serialized elsewhere (not visible here).
// Removals done by InvalidateFor are not counted in any field.
type poolStats struct {
Hits int `json:"hits"` // Checkout returned a live entry
Misses int `json:"misses"` // Checkout returned nil (empty key, unknown key, or expired entry)
Stores int `json:"stores"` // Checkin stored an entry
Evictions int `json:"evictions"` // entries dropped to keep the pool within poolMax
Expired int `json:"expired"` // entries dropped because they were idle longer than poolTTL
}
// NewConversationPool returns an empty pool and starts the background
// goroutine (pruneLoop) that periodically removes expired and overflow entries.
// NOTE(review): that goroutine has no stop signal, so every pool created here
// keeps a goroutine alive until the process exits.
func NewConversationPool() *ConversationPool {
	p := &ConversationPool{pool: map[string]*ConversationEntry{}}
	go p.pruneLoop()
	return p
}
// Checkout removes and returns the conversation stored under fingerprint.
//
// It returns nil, counting a miss, when fingerprint is empty or when nothing is
// stored under it. If the stored entry has been idle longer than poolTTL, it is
// discarded, counted as both expired and a miss, and nil is returned. Any other
// lookup counts as a hit. The entry leaves the pool in every found case, so two
// callers can never receive the same entry.
func (cp *ConversationPool) Checkout(fingerprint string) *ConversationEntry {
	cp.mu.Lock()
	defer cp.mu.Unlock()

	if fingerprint == "" {
		cp.stats.Misses++
		return nil
	}

	entry, found := cp.pool[fingerprint]
	if !found {
		cp.stats.Misses++
		return nil
	}

	// Remove unconditionally: a live entry now belongs to the caller, and a
	// stale one is garbage either way.
	delete(cp.pool, fingerprint)

	if time.Since(entry.LastAccess) <= poolTTL {
		cp.stats.Hits++
		return entry
	}
	cp.stats.Expired++
	cp.stats.Misses++
	return nil
}
// Checkin stores entry under fingerprint so a later Checkout with the same key
// can resume it. Any entry already stored under that fingerprint is replaced.
//
// LastAccess is set to the current time, and CreatedAt is set to the same time
// only if it is still zero. The pool keeps the pointer itself, not a copy. Every
// store also runs a prune pass (TTL expiry plus the poolMax cap). An empty
// fingerprint or a nil entry makes Checkin a no-op.
func (cp *ConversationPool) Checkin(fingerprint string, entry *ConversationEntry) {
	if entry == nil || fingerprint == "" {
		return
	}

	stamp := time.Now()
	cp.mu.Lock()
	defer cp.mu.Unlock()

	entry.LastAccess = stamp
	if entry.CreatedAt.IsZero() {
		entry.CreatedAt = stamp
	}
	cp.pool[fingerprint] = entry
	cp.stats.Stores++

	// Enforce TTL and size limits opportunistically on every store.
	cp.pruneLocked(stamp)
}
// InvalidateFor drops every pooled entry whose APIKey equals apiKey or whose
// LSPort equals lsPort, and returns how many entries were dropped.
// An empty apiKey disables the key criterion and a non-positive lsPort disables
// the port criterion; with both disabled nothing is removed. Removals made here
// are not recorded in the pool stats.
func (cp *ConversationPool) InvalidateFor(apiKey string, lsPort int) int {
	byKey := apiKey != ""
	byPort := lsPort > 0

	cp.mu.Lock()
	defer cp.mu.Unlock()

	removed := 0
	for fp, e := range cp.pool {
		switch {
		case byKey && e.APIKey == apiKey, byPort && e.LSPort == lsPort:
			// Deleting the current key while ranging over a map is safe in Go.
			delete(cp.pool, fp)
			removed++
		}
	}
	return removed
}
// pruneLocked removes every entry idle for longer than poolTTL (counted as
// Expired). If the pool is still larger than poolMax, it evicts the
// least-recently-accessed entries until exactly poolMax remain (counted as
// Evictions). now is the reference time for the TTL check.
//
// The caller must hold cp.mu.
func (cp *ConversationPool) pruneLocked(now time.Time) {
	for fp, e := range cp.pool {
		if now.Sub(e.LastAccess) > poolTTL {
			delete(cp.pool, fp)
			cp.stats.Expired++
		}
	}

	overflow := len(cp.pool) - poolMax
	if overflow <= 0 {
		return
	}

	// LRU eviction: order entries oldest-first and drop the surplus.
	type fpTime struct {
		fp string
		t  time.Time
	}
	entries := make([]fpTime, 0, len(cp.pool))
	for fp, e := range cp.pool {
		entries = append(entries, fpTime{fp, e.LastAccess})
	}
	// sort.Slice is O(n log n). The previous hand-rolled exchange sort was O(n²),
	// and this path runs on every Checkin once the pool is over its cap.
	// Ties have no defined order, which matches the randomness of map iteration.
	sort.Slice(entries, func(i, j int) bool {
		return entries[i].t.Before(entries[j].t)
	})
	for _, victim := range entries[:overflow] {
		delete(cp.pool, victim.fp)
		cp.stats.Evictions++
	}
}
// pruneLoop sweeps the pool every five minutes, calling pruneLocked with the
// current time while holding cp.mu.
// NOTE(review): the loop never exits (there is no stop channel or context), so
// the goroutine started by NewConversationPool lives until the process ends.
func (cp *ConversationPool) pruneLoop() {
	const sweepEvery = 5 * time.Minute
	ticker := time.NewTicker(sweepEvery)
	defer ticker.Stop()
	for {
		<-ticker.C
		cp.mu.Lock()
		cp.pruneLocked(time.Now())
		cp.mu.Unlock()
	}
}
// FingerprintBefore returns the pool key used to look up a conversation that
// this request could resume. It hashes every user/tool turn except the newest
// one, which should match the key FingerprintAfter produced at the end of the
// previous turn.
//
// apiKey is part of the hash because a cascade_id is bound to a specific
// upstream account/LS. Different accounts must never share a conversation, even
// when the messages are identical. Otherwise, after failover switches accounts,
// hitting the old cascade triggers "panel state not found".
//
// It returns "" when there are fewer than two user/tool turns.
func FingerprintBefore(messages []ChatMessage, modelKey, apiKey string) string {
	turns := stableTurns(messages)
	last := len(turns) - 1
	if last < 1 {
		return ""
	}
	return hashFingerprint(modelKey, apiKey, turns[:last])
}
// FingerprintAfter returns the pool key under which a conversation should be
// checked in after a successful turn. It hashes all user/tool turns of
// messages, together with modelKey and apiKey.
// It returns "" when messages contain no user/tool turns.
func FingerprintAfter(messages []ChatMessage, modelKey, apiKey string) string {
	if turns := stableTurns(messages); len(turns) > 0 {
		return hashFingerprint(modelKey, apiKey, turns)
	}
	return ""
}
// stableTurns returns, in their original order, only the messages whose Role is
// "user" or "tool". Every other role (e.g. assistant, system) is skipped.
// The result is nil when no message qualifies.
func stableTurns(messages []ChatMessage) []ChatMessage {
	var kept []ChatMessage
	for i := range messages {
		switch messages[i].Role {
		case "user", "tool":
			kept = append(kept, messages[i])
		}
	}
	return kept
}
// hashFingerprint returns the lowercase hex SHA-256 of
//
//	modelKey + "\x00\x00" + apiKey + "\x00\x00" + JSON([{role, content}, ...])
//
// Only Role and Content of each turn contribute. An empty or nil turns slice
// encodes as "[]".
func hashFingerprint(modelKey, apiKey string, turns []ChatMessage) string {
	type canonical struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}
	payload := make([]canonical, 0, len(turns))
	for _, turn := range turns {
		payload = append(payload, canonical{Role: turn.Role, Content: turn.Content})
	}
	// Marshalling a slice of string-only structs cannot fail (invalid UTF-8 is
	// coerced to U+FFFD, not rejected), so the error is deliberately discarded.
	encoded, _ := json.Marshal(payload)

	const sep = "\x00\x00"
	hasher := sha256.New()
	// hash.Hash.Write never returns an error.
	hasher.Write([]byte(modelKey + sep + apiKey + sep))
	hasher.Write(encoded)
	return fmt.Sprintf("%x", hasher.Sum(nil))
}