The test request was using maxOutputTokens: 1, which caused the Google API to generate only one token. When decoded, this single token produced "It" as the response, making it look like an error.

Changed:
- Content: "." → "Test connection" (more meaningful prompt)
- MaxTokens: 1 → 10 (enough tokens to verify that the connection is working)

This fixes the issue where the account test always showed "It" in the response, which was actually just the truncated output of the single-token generation.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
560 lines
14 KiB
Go
560 lines
14 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log/slog"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// CascadeSession holds the state of one Cascade agent conversation.
type CascadeSession struct {
	// ID is the session identifier.
	ID string

	// ModelName is the model requested for this session.
	ModelName string

	// Messages is the conversation history; each entry is a map with
	// "role" and "content" keys.
	Messages []map[string]interface{}

	// Metadata carries per-session attributes such as the device
	// fingerprint and User-Agent.
	Metadata map[string]string

	// Token is the OAuth token associated with the session.
	Token string

	// CreatedAt is the creation time as a Unix timestamp in milliseconds.
	CreatedAt int64
}
|
||
|
||
// LanguageServerService 业务逻辑层
|
||
// 处理 Cascade Agent 流程,转发到上游 API
|
||
type LanguageServerService struct {
|
||
// 会话管理
|
||
cascadeSessions map[string]*CascadeSession
|
||
sessionMutex sync.RWMutex
|
||
|
||
// 上游 HTTP 服务(用于发送请求)
|
||
httpUpstream HTTPUpstream
|
||
|
||
// 上游配置
|
||
upstreamBaseURL string
|
||
upstreamAPIKey string
|
||
|
||
// 日志
|
||
logger *slog.Logger
|
||
|
||
// 改进 1: 速率限制 (令牌桶)
|
||
// 限制并发消息处理数量,保护上游 API
|
||
rateLimiter chan struct{}
|
||
|
||
// 改进 3: 会话过期时间 (秒)
|
||
sessionTTLSeconds int64
|
||
|
||
// 改进 3: 定期清理后台任务
|
||
cleanupTicker *time.Ticker
|
||
stopCleanup chan struct{}
|
||
}
|
||
|
||
func NewLanguageServerService(
|
||
logger *slog.Logger,
|
||
httpUpstream HTTPUpstream,
|
||
) *LanguageServerService {
|
||
svc := &LanguageServerService{
|
||
cascadeSessions: make(map[string]*CascadeSession),
|
||
logger: logger,
|
||
httpUpstream: httpUpstream,
|
||
upstreamBaseURL: strings.TrimSuffix(os.Getenv("ANTHROPIC_BASE_URL"), "/"),
|
||
upstreamAPIKey: os.Getenv("ANTHROPIC_API_KEY"),
|
||
rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息
|
||
sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期
|
||
stopCleanup: make(chan struct{}),
|
||
}
|
||
|
||
// 改进 3: 启动后台清理任务
|
||
svc.startSessionCleanup()
|
||
|
||
return svc
|
||
}
|
||
|
||
// startSessionCleanup 启动会话定期清理任务
|
||
func (svc *LanguageServerService) startSessionCleanup() {
|
||
svc.cleanupTicker = time.NewTicker(1 * time.Minute)
|
||
|
||
go func() {
|
||
for {
|
||
select {
|
||
case <-svc.cleanupTicker.C:
|
||
svc.cleanupExpiredSessions()
|
||
case <-svc.stopCleanup:
|
||
svc.cleanupTicker.Stop()
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// cleanupExpiredSessions 清理过期的会话
|
||
func (svc *LanguageServerService) cleanupExpiredSessions() {
|
||
now := getCurrentTimeMS()
|
||
ttlMs := svc.sessionTTLSeconds * 1000
|
||
|
||
svc.sessionMutex.Lock()
|
||
defer svc.sessionMutex.Unlock()
|
||
|
||
deletedCount := 0
|
||
for id, session := range svc.cascadeSessions {
|
||
if now-session.CreatedAt > ttlMs {
|
||
delete(svc.cascadeSessions, id)
|
||
deletedCount++
|
||
}
|
||
}
|
||
|
||
if deletedCount > 0 {
|
||
svc.logger.Info("expired sessions cleaned up",
|
||
"deleted_count", deletedCount,
|
||
"remaining_sessions", len(svc.cascadeSessions),
|
||
)
|
||
}
|
||
}
|
||
|
||
// Stop 优雅关闭服务
|
||
func (svc *LanguageServerService) Stop() {
|
||
select {
|
||
case svc.stopCleanup <- struct{}{}:
|
||
default:
|
||
}
|
||
}
|
||
|
||
// SetSessionTTL sets the session TTL for testing purposes
|
||
func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) {
|
||
svc.sessionTTLSeconds = ttlSeconds
|
||
}
|
||
|
||
// GetCascadeSessions returns the current cascade sessions map (for testing)
|
||
func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession {
|
||
svc.sessionMutex.RLock()
|
||
defer svc.sessionMutex.RUnlock()
|
||
return svc.cascadeSessions
|
||
}
|
||
|
||
// ============================================================================
|
||
// Cascade 业务逻辑
|
||
// ============================================================================
|
||
|
||
// StartCascade 启动新的 Cascade Agent 会话
|
||
func (svc *LanguageServerService) StartCascade(
|
||
ctx context.Context,
|
||
model string,
|
||
systemPrompt string,
|
||
metadata map[string]string,
|
||
token string,
|
||
) (string, error) {
|
||
// 1. 验证输入
|
||
if model == "" {
|
||
return "", fmt.Errorf("model is required")
|
||
}
|
||
|
||
if token == "" {
|
||
return "", fmt.Errorf("oauth token is required")
|
||
}
|
||
|
||
// 2. 生成会话 ID
|
||
sessionID := uuid.New().String()
|
||
|
||
// 3. 创建会话
|
||
session := &CascadeSession{
|
||
ID: sessionID,
|
||
ModelName: model,
|
||
Messages: make([]map[string]interface{}, 0),
|
||
Metadata: metadata,
|
||
Token: token,
|
||
CreatedAt: getCurrentTimeMS(),
|
||
}
|
||
|
||
// 如果提供了系统提示,添加为初始消息
|
||
if systemPrompt != "" {
|
||
session.Messages = append(session.Messages, map[string]interface{}{
|
||
"role": "user",
|
||
"content": systemPrompt,
|
||
})
|
||
}
|
||
|
||
// 4. 保存会话
|
||
svc.sessionMutex.Lock()
|
||
svc.cascadeSessions[sessionID] = session
|
||
svc.sessionMutex.Unlock()
|
||
|
||
svc.logger.Info("cascade session started",
|
||
"session_id", sessionID,
|
||
"model", model,
|
||
"has_system_prompt", systemPrompt != "")
|
||
|
||
return sessionID, nil
|
||
}
|
||
|
||
// SendUserMessage 发送用户消息到 Cascade
|
||
// 返回流式更新通道
|
||
func (svc *LanguageServerService) SendUserMessage(
|
||
ctx context.Context,
|
||
cascadeID string,
|
||
userMessage string,
|
||
token string,
|
||
) (<-chan interface{}, error) {
|
||
// 改进 1: 获取速率限制令牌
|
||
select {
|
||
case svc.rateLimiter <- struct{}{}:
|
||
// 获得令牌,继续
|
||
case <-ctx.Done():
|
||
return nil, fmt.Errorf("context cancelled")
|
||
default:
|
||
// 没有令牌,需要等待
|
||
select {
|
||
case svc.rateLimiter <- struct{}{}:
|
||
// 获得令牌
|
||
case <-ctx.Done():
|
||
return nil, fmt.Errorf("context cancelled while waiting for rate limit")
|
||
case <-time.After(30 * time.Second):
|
||
return nil, fmt.Errorf("rate limit timeout: too many concurrent messages")
|
||
}
|
||
}
|
||
|
||
// 1. 获取会话
|
||
svc.sessionMutex.RLock()
|
||
session, exists := svc.cascadeSessions[cascadeID]
|
||
svc.sessionMutex.RUnlock()
|
||
|
||
if !exists {
|
||
// 释放令牌
|
||
<-svc.rateLimiter
|
||
return nil, fmt.Errorf("cascade session not found: %s", cascadeID)
|
||
}
|
||
|
||
// 2. 验证 token
|
||
if token != session.Token {
|
||
// 释放令牌
|
||
<-svc.rateLimiter
|
||
return nil, fmt.Errorf("invalid token for session")
|
||
}
|
||
|
||
// 改进 2: 并发安全的消息追加(深拷贝消息列表)
|
||
svc.sessionMutex.Lock()
|
||
newMessages := make([]map[string]interface{}, len(session.Messages)+1)
|
||
copy(newMessages, session.Messages)
|
||
newMessages[len(newMessages)-1] = map[string]interface{}{
|
||
"role": "user",
|
||
"content": userMessage,
|
||
}
|
||
session.Messages = newMessages
|
||
svc.sessionMutex.Unlock()
|
||
|
||
// 4. 创建响应通道
|
||
updateChan := make(chan interface{}, 100)
|
||
|
||
// 5. 启动后台 goroutine 处理 API 调用
|
||
go func() {
|
||
defer func() {
|
||
// 关闭通道
|
||
close(updateChan)
|
||
// 改进 1: 释放速率限制令牌
|
||
<-svc.rateLimiter
|
||
}()
|
||
|
||
// 调用上游 API(关键!这里需要伪装)
|
||
svc.callUpstreamAPI(ctx, session, updateChan)
|
||
}()
|
||
|
||
svc.logger.Info("user message sent to cascade",
|
||
"session_id", cascadeID,
|
||
"message_length", len(userMessage),
|
||
"concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数
|
||
)
|
||
|
||
return updateChan, nil
|
||
}
|
||
|
||
// CancelCascade 取消 Cascade 会话
|
||
func (svc *LanguageServerService) CancelCascade(
|
||
ctx context.Context,
|
||
cascadeID string,
|
||
) error {
|
||
svc.sessionMutex.Lock()
|
||
_, exists := svc.cascadeSessions[cascadeID]
|
||
svc.sessionMutex.Unlock()
|
||
|
||
if !exists {
|
||
return fmt.Errorf("cascade session not found: %s", cascadeID)
|
||
}
|
||
|
||
// TODO: 取消正在进行的 API 调用
|
||
|
||
svc.logger.Info("cascade cancelled", "session_id", cascadeID)
|
||
return nil
|
||
}
|
||
|
||
// ============================================================================
|
||
// 模型配置
|
||
// ============================================================================
|
||
|
||
// ModelConfig describes one model offered by the service.
type ModelConfig struct {
	// Name is the model identifier, e.g. "claude-opus-4-6".
	Name string

	// DisplayName is the human-readable label, e.g. "Claude Opus 4.6".
	DisplayName string

	// MaxTokens is the token limit advertised for the model.
	// NOTE(review): whether this is the context window or the output limit
	// is not visible here — confirm with the consumers of this field.
	MaxTokens int

	// SupportsThinking tells whether the model supports thinking.
	SupportsThinking bool

	// ThinkingBudget is the thinking budget; it is left at zero for models
	// that do not support thinking. Presumably counted in tokens — confirm.
	ThinkingBudget int

	// SupportsImages tells whether the model accepts image input.
	SupportsImages bool

	// Provider names the vendor of the model, e.g. "anthropic" or "google".
	Provider string
}
|
||
|
||
// GetAvailableModels 获取可用模型列表
|
||
func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) {
|
||
models := []ModelConfig{
|
||
{
|
||
Name: "claude-opus-4-6",
|
||
DisplayName: "Claude Opus 4.6",
|
||
MaxTokens: 200000,
|
||
SupportsThinking: true,
|
||
ThinkingBudget: 32000,
|
||
SupportsImages: true,
|
||
Provider: "anthropic",
|
||
},
|
||
{
|
||
Name: "claude-sonnet-4-6",
|
||
DisplayName: "Claude Sonnet 4.6",
|
||
MaxTokens: 200000,
|
||
SupportsThinking: false,
|
||
SupportsImages: true,
|
||
Provider: "anthropic",
|
||
},
|
||
{
|
||
Name: "claude-haiku-4-5",
|
||
DisplayName: "Claude Haiku 4.5",
|
||
MaxTokens: 200000,
|
||
SupportsThinking: false,
|
||
SupportsImages: true,
|
||
Provider: "anthropic",
|
||
},
|
||
{
|
||
Name: "gemini-3-pro",
|
||
DisplayName: "Gemini 3 Pro",
|
||
MaxTokens: 128000,
|
||
SupportsThinking: false,
|
||
SupportsImages: true,
|
||
Provider: "google",
|
||
},
|
||
}
|
||
|
||
return models, nil
|
||
}
|
||
|
||
// ============================================================================
|
||
// 状态查询
|
||
// ============================================================================
|
||
|
||
// GetStatus 获取服务状态
|
||
func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) {
|
||
// TODO: 检查上游 API 连接状态
|
||
return "running", nil
|
||
}
|
||
|
||
// ============================================================================
|
||
// 内部方法
|
||
// ============================================================================
|
||
|
||
// callUpstreamAPI 调用上游 Anthropic API
|
||
// 这是关键方法:需要注入所有伪装信息
|
||
//
|
||
// 伪装层包括:
|
||
// 1. User-Agent(来自 metadata 或动态生成)
|
||
// 2. 设备指纹(machine_id, mac_machine_id, dev_device_id, sqm_id)
|
||
// 3. TLS 指纹(通过 http.Transport 处理)
|
||
// 4. OAuth token 自动刷新
|
||
// 5. 请求头完整性
|
||
func (svc *LanguageServerService) callUpstreamAPI(
|
||
ctx context.Context,
|
||
session *CascadeSession,
|
||
updateChan chan<- interface{},
|
||
) {
|
||
// 检查上游配置
|
||
if svc.upstreamBaseURL == "" || svc.upstreamAPIKey == "" {
|
||
svc.logger.Error("upstream api configuration missing",
|
||
"has_base_url", svc.upstreamBaseURL != "",
|
||
"has_api_key", svc.upstreamAPIKey != "",
|
||
)
|
||
updateChan <- map[string]interface{}{
|
||
"type": "error",
|
||
"error": "upstream api not configured",
|
||
}
|
||
return
|
||
}
|
||
|
||
// 1. 准备请求体
|
||
requestBody := map[string]interface{}{
|
||
"model": session.ModelName,
|
||
"messages": session.Messages,
|
||
"stream": true,
|
||
}
|
||
|
||
bodyJSON, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
svc.logger.Error("failed to marshal request",
|
||
"session_id", session.ID,
|
||
"error", err,
|
||
)
|
||
updateChan <- map[string]interface{}{
|
||
"type": "error",
|
||
"error": "failed to prepare request",
|
||
}
|
||
return
|
||
}
|
||
|
||
// 2. 构建上游请求 URL
|
||
upstreamURL := svc.upstreamBaseURL + "/v1/messages"
|
||
|
||
// 3. 创建 HTTP 请求
|
||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(bodyJSON))
|
||
if err != nil {
|
||
svc.logger.Error("failed to create request",
|
||
"session_id", session.ID,
|
||
"error", err,
|
||
)
|
||
updateChan <- map[string]interface{}{
|
||
"type": "error",
|
||
"error": "failed to create request",
|
||
}
|
||
return
|
||
}
|
||
|
||
// 4. 设置基础请求头
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+session.Token)
|
||
req.Header.Set("x-api-key", session.Token) // Claude API 兼容
|
||
|
||
// 5. 应用伪装信息
|
||
if userAgent := session.Metadata["user-agent"]; userAgent != "" {
|
||
req.Header.Set("User-Agent", userAgent)
|
||
}
|
||
|
||
// 提取其他伪装 headers(如果在 metadata 中)
|
||
if customHeaders := session.Metadata["custom-headers"]; customHeaders != "" {
|
||
// 可以在这里解析并应用自定义 headers
|
||
}
|
||
|
||
svc.logger.Debug("sending upstream request",
|
||
"session_id", session.ID,
|
||
"url", upstreamURL,
|
||
"model", session.ModelName,
|
||
)
|
||
|
||
// 6. 发送请求
|
||
resp, err := svc.httpUpstream.Do(req, "", 0, 10)
|
||
if err != nil {
|
||
svc.logger.Error("upstream request failed",
|
||
"session_id", session.ID,
|
||
"error", err,
|
||
)
|
||
updateChan <- map[string]interface{}{
|
||
"type": "error",
|
||
"error": fmt.Sprintf("upstream request failed: %v", err),
|
||
}
|
||
return
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
// 7. 处理错误响应
|
||
if resp.StatusCode >= 400 {
|
||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||
svc.logger.Error("upstream error response",
|
||
"session_id", session.ID,
|
||
"status_code", resp.StatusCode,
|
||
"body", string(respBody),
|
||
)
|
||
updateChan <- map[string]interface{}{
|
||
"type": "error",
|
||
"status_code": resp.StatusCode,
|
||
"error": string(respBody),
|
||
}
|
||
return
|
||
}
|
||
|
||
// 8. 处理流式响应
|
||
svc.streamUpstreamResponse(ctx, session.ID, resp, updateChan)
|
||
}
|
||
|
||
// streamUpstreamResponse 处理上游 SSE 流式响应
|
||
func (svc *LanguageServerService) streamUpstreamResponse(
|
||
ctx context.Context,
|
||
sessionID string,
|
||
resp *http.Response,
|
||
updateChan chan<- interface{},
|
||
) {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
// 设置合理的缓冲区大小
|
||
scanner.Buffer(make([]byte, 64*1024), 512*1024)
|
||
|
||
for scanner.Scan() {
|
||
select {
|
||
case <-ctx.Done():
|
||
svc.logger.Info("streaming cancelled", "session_id", sessionID)
|
||
return
|
||
default:
|
||
}
|
||
|
||
line := strings.TrimSpace(scanner.Text())
|
||
|
||
// 跳过空行
|
||
if line == "" {
|
||
continue
|
||
}
|
||
|
||
// 跳过注释行
|
||
if strings.HasPrefix(line, ":") {
|
||
continue
|
||
}
|
||
|
||
// 解析 SSE 格式 (data: {...})
|
||
if !strings.HasPrefix(line, "data:") {
|
||
continue
|
||
}
|
||
|
||
eventData := strings.TrimPrefix(line, "data:")
|
||
eventData = strings.TrimSpace(eventData)
|
||
|
||
// 解析 JSON
|
||
var event map[string]interface{}
|
||
if err := json.Unmarshal([]byte(eventData), &event); err != nil {
|
||
svc.logger.Debug("failed to parse event",
|
||
"session_id", sessionID,
|
||
"error", err,
|
||
"data", eventData,
|
||
)
|
||
continue
|
||
}
|
||
|
||
// 发送事件到客户端通道
|
||
select {
|
||
case updateChan <- event:
|
||
case <-ctx.Done():
|
||
return
|
||
case <-time.After(5 * time.Second):
|
||
svc.logger.Warn("channel send timeout",
|
||
"session_id", sessionID,
|
||
)
|
||
return
|
||
}
|
||
}
|
||
|
||
if err := scanner.Err(); err != nil {
|
||
svc.logger.Error("scanning upstream response failed",
|
||
"session_id", sessionID,
|
||
"error", err,
|
||
)
|
||
}
|
||
}
|
||
|
||
// getCurrentTimeMS returns the current wall-clock time as a Unix timestamp in
// milliseconds.
func getCurrentTimeMS() int64 {
	now := time.Now()
	return now.UnixMilli()
}
|