235 lines
5.3 KiB
Go
235 lines
5.3 KiB
Go
package windsurf
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// RawGetChatMessageRPC is the full gRPC method path invoked by
// StreamLegacyChat for the legacy chat endpoint.
const RawGetChatMessageRPC = "/exa.language_server_pb.LanguageServerService/RawGetChatMessage"

// Source values written as field 2 of each encoded legacy chat message.
// The numeric values are part of the wire format.
// NOTE(review): presumably mirrored from the server's message-source enum —
// confirm against the .proto before renumbering.
const (
	SourceUser      = iota + 1 // 1
	SourceSystem               // 2
	SourceAssistant            // 3
	SourceTool                 // 4
)
|
||
|
||
type ChatMessage struct {
|
||
Role string `json:"role"`
|
||
Content string `json:"content"`
|
||
// Images 当前消息携带的图片(通常只有 user role 才非空)。
|
||
// 仅用于发送/replay。CascadeImage.Base64Data 不应出现在持久化/日志/指纹中。
|
||
Images []CascadeImage `json:"images,omitempty"`
|
||
// ImageDigests 仅摘要视图(含 sha256 / mime / byte_len / caption,不含 base64)。
|
||
// 供 conversation pool 指纹与日志使用。
|
||
ImageDigests []ImageDigest `json:"image_digests,omitempty"`
|
||
}
|
||
|
||
// LegacyChatDelta is the decoded payload of a single RawGetChatMessage
// response, as produced by ParseRawChatResponse.
type LegacyChatDelta struct {
	// Text is the reply text (nested length-delimited field 5).
	Text string
	// InProgress and IsError mirror nested varint fields 6 and 7; any
	// non-zero wire value decodes as true.
	InProgress, IsError bool
}
|
||
|
||
func encodeTimestamp() []byte {
|
||
now := time.Now()
|
||
secs := uint64(now.Unix())
|
||
nanos := uint64(now.Nanosecond())
|
||
out := encodeVarintField(1, secs)
|
||
if nanos > 0 {
|
||
out = append(out, encodeVarintField(2, nanos)...)
|
||
}
|
||
return out
|
||
}
|
||
|
||
func buildChatMessage(content string, source int, conversationID string) []byte {
|
||
var parts []byte
|
||
parts = append(parts, encodeStringField(1, generateUUID())...)
|
||
parts = append(parts, encodeVarintField(2, uint64(source))...)
|
||
parts = append(parts, encodeBytesField(3, encodeTimestamp())...)
|
||
parts = append(parts, encodeStringField(4, conversationID)...)
|
||
|
||
if source == SourceAssistant {
|
||
actionGeneric := encodeStringField(1, content)
|
||
action := encodeBytesField(1, actionGeneric)
|
||
parts = append(parts, encodeBytesField(6, action)...)
|
||
} else {
|
||
intentGeneric := encodeStringField(1, content)
|
||
intent := encodeBytesField(1, intentGeneric)
|
||
parts = append(parts, encodeBytesField(5, intent)...)
|
||
}
|
||
|
||
return parts
|
||
}
|
||
|
||
func BuildRawGetChatMessageRequest(apiKey string, messages []ChatMessage, modelEnum int, modelName string) []byte {
|
||
var parts []byte
|
||
conversationID := generateUUID()
|
||
|
||
parts = append(parts, encodeBytesField(1, buildMetadata(apiKey, generateUUID()))...)
|
||
|
||
var systemPrompt string
|
||
for _, msg := range messages {
|
||
if msg.Role == "system" {
|
||
if systemPrompt != "" {
|
||
systemPrompt += "\n"
|
||
}
|
||
systemPrompt += msg.Content
|
||
continue
|
||
}
|
||
|
||
var source int
|
||
var text string
|
||
|
||
switch msg.Role {
|
||
case "user":
|
||
source = SourceUser
|
||
text = msg.Content
|
||
case "assistant":
|
||
source = SourceAssistant
|
||
text = msg.Content
|
||
case "tool":
|
||
source = SourceUser
|
||
text = "[tool result]: " + msg.Content
|
||
default:
|
||
source = SourceUser
|
||
text = msg.Content
|
||
}
|
||
|
||
parts = append(parts, encodeBytesField(2, buildChatMessage(text, source, conversationID))...)
|
||
}
|
||
|
||
if systemPrompt != "" {
|
||
parts = append(parts, encodeStringField(3, systemPrompt)...)
|
||
}
|
||
|
||
parts = append(parts, encodeVarintField(4, uint64(modelEnum))...)
|
||
|
||
if modelName != "" {
|
||
parts = append(parts, encodeStringField(5, modelName)...)
|
||
}
|
||
|
||
return parts
|
||
}
|
||
|
||
func ParseRawChatResponse(data []byte) LegacyChatDelta {
|
||
pos := 0
|
||
var deltaMsg []byte
|
||
for pos < len(data) {
|
||
tag, np, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
switch wireType {
|
||
case 2:
|
||
length, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
return LegacyChatDelta{}
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(data) {
|
||
return LegacyChatDelta{}
|
||
}
|
||
field := data[pos : pos+int(length)]
|
||
pos += int(length)
|
||
if fieldNum == 1 {
|
||
deltaMsg = field
|
||
}
|
||
case 0:
|
||
_, np2, ok := ReadVarint(data, pos)
|
||
if !ok {
|
||
return LegacyChatDelta{}
|
||
}
|
||
pos = np2
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
return LegacyChatDelta{}
|
||
}
|
||
}
|
||
|
||
if deltaMsg == nil {
|
||
return LegacyChatDelta{}
|
||
}
|
||
|
||
var result LegacyChatDelta
|
||
pos = 0
|
||
for pos < len(deltaMsg) {
|
||
tag, np, ok := ReadVarint(deltaMsg, pos)
|
||
if !ok {
|
||
break
|
||
}
|
||
pos = np
|
||
fieldNum := tag >> 3
|
||
wireType := tag & 7
|
||
|
||
switch wireType {
|
||
case 2:
|
||
length, np2, ok := ReadVarint(deltaMsg, pos)
|
||
if !ok {
|
||
return result
|
||
}
|
||
pos = np2
|
||
if pos+int(length) > len(deltaMsg) {
|
||
return result
|
||
}
|
||
field := deltaMsg[pos : pos+int(length)]
|
||
pos += int(length)
|
||
if fieldNum == 5 {
|
||
result.Text = string(field)
|
||
}
|
||
case 0:
|
||
val, np2, ok := ReadVarint(deltaMsg, pos)
|
||
if !ok {
|
||
return result
|
||
}
|
||
pos = np2
|
||
if fieldNum == 6 {
|
||
result.InProgress = val != 0
|
||
} else if fieldNum == 7 {
|
||
result.IsError = val != 0
|
||
}
|
||
case 1:
|
||
pos += 8
|
||
case 5:
|
||
pos += 4
|
||
default:
|
||
pos = len(deltaMsg)
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
func (l *LocalLSClient) StreamLegacyChat(ctx context.Context, token string, messages []ChatMessage, modelEnum int, modelName string) (string, error) {
|
||
reqBody := BuildRawGetChatMessageRequest(token, messages, modelEnum, modelName)
|
||
|
||
respData, err := l.grpcUnaryRaw(ctx, RawGetChatMessageRPC, reqBody)
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "panel state not found") || strings.Contains(err.Error(), "not_found") {
|
||
_ = l.ForceWarmupCascade(ctx, token)
|
||
respData, err = l.grpcUnaryRaw(ctx, RawGetChatMessageRPC, reqBody)
|
||
if err != nil {
|
||
return "", fmt.Errorf("legacy chat retry: %w", err)
|
||
}
|
||
} else {
|
||
return "", fmt.Errorf("legacy chat: %w", err)
|
||
}
|
||
}
|
||
|
||
delta := ParseRawChatResponse(respData)
|
||
if delta.IsError {
|
||
return "", fmt.Errorf("legacy chat error: %s", delta.Text)
|
||
}
|
||
|
||
return SanitizePath(delta.Text), nil
|
||
}
|