235 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package windsurf
import (
"context"
"fmt"
"strings"
"time"
)
// RawGetChatMessageRPC is the full gRPC method path of the legacy
// RawGetChatMessage endpoint; StreamLegacyChat passes it to grpcUnaryRaw.
const RawGetChatMessageRPC = "/exa.language_server_pb.LanguageServerService/RawGetChatMessage"

// Message source values written into field 2 of every encoded chat message
// (see buildChatMessage). They are numbered consecutively starting at 1.
const (
	SourceUser      = iota + 1 // 1
	SourceSystem               // 2
	SourceAssistant            // 3
	SourceTool                 // 4
)
// ChatMessage is one conversation turn passed to BuildRawGetChatMessageRequest
// and StreamLegacyChat. When the legacy request is built, Role is matched
// against "system", "user", "assistant" and "tool" (anything else is sent as a
// user turn), and Content is the turn's text. BuildRawGetChatMessageRequest
// reads only Role and Content.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// Images holds the images attached to this message (usually non-empty
	// only for the user role). It is used only for sending/replay;
	// CascadeImage.Base64Data must never appear in persisted data, logs, or
	// fingerprints.
	Images []CascadeImage `json:"images,omitempty"`
	// ImageDigests is a digest-only view of the images (sha256 / mime /
	// byte_len / caption, without the base64 payload), used by the
	// conversation pool for fingerprinting and logging.
	ImageDigests []ImageDigest `json:"image_digests,omitempty"`
}
// LegacyChatDelta is the decoded payload of a RawGetChatMessage response, as
// produced by ParseRawChatResponse.
type LegacyChatDelta struct {
	// Text is field 5 of the nested delta message. When IsError is set,
	// StreamLegacyChat reports it as the error text.
	Text string
	// InProgress and IsError mirror varint fields 6 and 7 of the delta
	// message; each is true when the encoded value is non-zero.
	InProgress, IsError bool
}
// encodeTimestamp encodes the current wall-clock time as a protobuf message
// with Unix seconds in varint field 1 and the sub-second nanoseconds in
// varint field 2. Field 2 is left out when the nanosecond part is zero.
// This is laid out like google.protobuf.Timestamp (presumably what the server
// expects here — confirm against the proto definition).
func encodeTimestamp() []byte {
	t := time.Now()
	buf := encodeVarintField(1, uint64(t.Unix()))
	if ns := t.Nanosecond(); ns != 0 {
		buf = append(buf, encodeVarintField(2, uint64(ns))...)
	}
	return buf
}
func buildChatMessage(content string, source int, conversationID string) []byte {
var parts []byte
parts = append(parts, encodeStringField(1, generateUUID())...)
parts = append(parts, encodeVarintField(2, uint64(source))...)
parts = append(parts, encodeBytesField(3, encodeTimestamp())...)
parts = append(parts, encodeStringField(4, conversationID)...)
if source == SourceAssistant {
actionGeneric := encodeStringField(1, content)
action := encodeBytesField(1, actionGeneric)
parts = append(parts, encodeBytesField(6, action)...)
} else {
intentGeneric := encodeStringField(1, content)
intent := encodeBytesField(1, intentGeneric)
parts = append(parts, encodeBytesField(5, intent)...)
}
return parts
}
// BuildRawGetChatMessageRequest encodes the protobuf body of a
// RawGetChatMessage request.
//
// Fields written:
//
//	1: metadata from buildMetadata(apiKey, <fresh UUID>)
//	2: one buildChatMessage entry per non-system message, in input order
//	3: all "system" message contents joined with "\n" (omitted if empty)
//	4: modelEnum (always written)
//	5: modelName (omitted if empty)
//
// Role mapping: "user" -> SourceUser, "assistant" -> SourceAssistant,
// "tool" -> SourceUser with the text prefixed by "[tool result]: ", and any
// other role -> SourceUser. All messages share one freshly generated
// conversation ID. Only Role and Content are read from each ChatMessage.
func BuildRawGetChatMessageRequest(apiKey string, messages []ChatMessage, modelEnum int, modelName string) []byte {
	var parts []byte
	conversationID := generateUUID()
	parts = append(parts, encodeBytesField(1, buildMetadata(apiKey, generateUUID()))...)

	// strings.Builder avoids the quadratic cost of repeated += when many
	// system messages are present. The separator rule is unchanged: "\n" is
	// written only when something has already been accumulated, so leading
	// empty system messages contribute nothing.
	var system strings.Builder
	for _, msg := range messages {
		source, text := SourceUser, msg.Content
		switch msg.Role {
		case "system":
			if system.Len() > 0 {
				system.WriteByte('\n')
			}
			system.WriteString(msg.Content)
			continue
		case "assistant":
			source = SourceAssistant
		case "tool":
			// The legacy endpoint gets no dedicated tool source here; tool
			// output is sent as a labelled user turn.
			text = "[tool result]: " + msg.Content
		}
		parts = append(parts, encodeBytesField(2, buildChatMessage(text, source, conversationID))...)
	}

	if system.Len() > 0 {
		parts = append(parts, encodeStringField(3, system.String())...)
	}
	parts = append(parts, encodeVarintField(4, uint64(modelEnum))...)
	if modelName != "" {
		parts = append(parts, encodeStringField(5, modelName)...)
	}
	return parts
}
// ParseRawChatResponse decodes a RawGetChatMessage response body.
//
// The outer message is scanned for field 1 (length-delimited). If it appears
// more than once, the last occurrence wins. That nested delta message is then
// scanned for:
//
//	5 (length-delimited) -> Text
//	6 (varint)           -> InProgress (non-zero means true)
//	7 (varint)           -> IsError    (non-zero means true)
//
// Malformed input never panics. In the outer message, a bad length or value,
// an out-of-range length, or an unsupported wire type yields the zero
// LegacyChatDelta. A tag that fails to decode ends the outer scan but keeps
// whatever was found. Inside the delta message, any decode problem stops the
// scan and returns the fields decoded so far.
func ParseRawChatResponse(data []byte) LegacyChatDelta {
	// Pass 1: locate the nested delta message (outer field 1).
	pos := 0
	var deltaMsg []byte
	for pos < len(data) {
		tag, np, ok := ReadVarint(data, pos)
		if !ok {
			break
		}
		pos = np
		fieldNum := tag >> 3
		wireType := tag & 7
		switch wireType {
		case 2: // length-delimited
			length, np2, ok := ReadVarint(data, pos)
			if !ok {
				return LegacyChatDelta{}
			}
			pos = np2
			// Bounds-check in the int domain with an explicit sign test. A
			// length varint >= 2^63 converts to a negative int, which would
			// pass a plain "pos+n > len" test and then panic when slicing.
			n := int(length)
			if n < 0 || n > len(data)-pos {
				return LegacyChatDelta{}
			}
			if fieldNum == 1 {
				deltaMsg = data[pos : pos+n]
			}
			pos += n
		case 0: // varint
			_, np2, ok := ReadVarint(data, pos)
			if !ok {
				return LegacyChatDelta{}
			}
			pos = np2
		case 1: // fixed64
			pos += 8
		case 5: // fixed32
			pos += 4
		default: // groups and invalid wire types are not supported
			return LegacyChatDelta{}
		}
	}
	if deltaMsg == nil {
		return LegacyChatDelta{}
	}

	// Pass 2: decode the delta message, keeping whatever parses cleanly.
	var result LegacyChatDelta
	pos = 0
	for pos < len(deltaMsg) {
		tag, np, ok := ReadVarint(deltaMsg, pos)
		if !ok {
			break
		}
		pos = np
		fieldNum := tag >> 3
		wireType := tag & 7
		switch wireType {
		case 2: // length-delimited
			length, np2, ok := ReadVarint(deltaMsg, pos)
			if !ok {
				return result
			}
			pos = np2
			// Same overflow-safe bounds check as in pass 1.
			n := int(length)
			if n < 0 || n > len(deltaMsg)-pos {
				return result
			}
			if fieldNum == 5 {
				result.Text = string(deltaMsg[pos : pos+n])
			}
			pos += n
		case 0: // varint
			val, np2, ok := ReadVarint(deltaMsg, pos)
			if !ok {
				return result
			}
			pos = np2
			switch fieldNum {
			case 6:
				result.InProgress = val != 0
			case 7:
				result.IsError = val != 0
			}
		case 1: // fixed64
			pos += 8
		case 5: // fixed32
			pos += 4
		default: // unsupported wire type: stop and keep what we have
			pos = len(deltaMsg)
		}
	}
	return result
}
// StreamLegacyChat sends messages to the language server's legacy
// RawGetChatMessage RPC and returns the decoded reply text.
//
// Despite the name, it makes one grpcUnaryRaw call (plus at most one retry)
// and returns the whole reply at once. It does not hand incremental deltas to
// the caller.
//
// If the first call fails with an error whose text contains
// "panel state not found" or "not_found", it calls ForceWarmupCascade and
// retries once with the same request body. The reply is decoded with
// ParseRawChatResponse. If the delta is flagged IsError, its Text becomes the
// error message. Otherwise the text is returned after SanitizePath.
func (l *LocalLSClient) StreamLegacyChat(ctx context.Context, token string, messages []ChatMessage, modelEnum int, modelName string) (string, error) {
	reqBody := BuildRawGetChatMessageRequest(token, messages, modelEnum, modelName)

	respData, err := l.grpcUnaryRaw(ctx, RawGetChatMessageRPC, reqBody)
	if err != nil {
		msg := err.Error()
		if !strings.Contains(msg, "panel state not found") && !strings.Contains(msg, "not_found") {
			return "", fmt.Errorf("legacy chat: %w", err)
		}
		// Warmup is best-effort: its failure does not prevent the retry. It is
		// the likely cause of a failed retry, though, so report it instead of
		// silently discarding it.
		warmErr := l.ForceWarmupCascade(ctx, token)
		respData, err = l.grpcUnaryRaw(ctx, RawGetChatMessageRPC, reqBody)
		if err != nil {
			if warmErr != nil {
				return "", fmt.Errorf("legacy chat retry: %w (warmup failed: %v)", err, warmErr)
			}
			return "", fmt.Errorf("legacy chat retry: %w", err)
		}
	}

	delta := ParseRawChatResponse(respData)
	if delta.IsError {
		return "", fmt.Errorf("legacy chat error: %s", delta.Text)
	}
	return SanitizePath(delta.Text), nil
}