package service import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "log/slog" "net/http" "os" "strings" "sync" "time" "github.com/google/uuid" ) // CascadeSession 代表一个 Cascade Agent 会话 type CascadeSession struct { ID string ModelName string Messages []map[string]interface{} // {role, content} Metadata map[string]string // 设备指纹、User-Agent 等 Token string // OAuth token CreatedAt int64 } // LanguageServerService 业务逻辑层 // 处理 Cascade Agent 流程,转发到上游 API type LanguageServerService struct { // 会话管理 cascadeSessions map[string]*CascadeSession sessionMutex sync.RWMutex // 上游 HTTP 服务(用于发送请求) httpUpstream HTTPUpstream // 上游配置 upstreamBaseURL string upstreamAPIKey string // 日志 logger *slog.Logger // 改进 1: 速率限制 (令牌桶) // 限制并发消息处理数量,保护上游 API rateLimiter chan struct{} // 改进 3: 会话过期时间 (秒) sessionTTLSeconds int64 // 改进 3: 定期清理后台任务 cleanupTicker *time.Ticker stopCleanup chan struct{} } func NewLanguageServerService( logger *slog.Logger, httpUpstream HTTPUpstream, ) *LanguageServerService { svc := &LanguageServerService{ cascadeSessions: make(map[string]*CascadeSession), logger: logger, httpUpstream: httpUpstream, upstreamBaseURL: strings.TrimSuffix(os.Getenv("ANTHROPIC_BASE_URL"), "/"), upstreamAPIKey: os.Getenv("ANTHROPIC_API_KEY"), rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息 sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期 stopCleanup: make(chan struct{}), } // 改进 3: 启动后台清理任务 svc.startSessionCleanup() return svc } // startSessionCleanup 启动会话定期清理任务 func (svc *LanguageServerService) startSessionCleanup() { svc.cleanupTicker = time.NewTicker(1 * time.Minute) go func() { for { select { case <-svc.cleanupTicker.C: svc.cleanupExpiredSessions() case <-svc.stopCleanup: svc.cleanupTicker.Stop() return } } }() } // cleanupExpiredSessions 清理过期的会话 func (svc *LanguageServerService) cleanupExpiredSessions() { now := getCurrentTimeMS() ttlMs := svc.sessionTTLSeconds * 1000 svc.sessionMutex.Lock() defer svc.sessionMutex.Unlock() deletedCount := 0 for id, session := range svc.cascadeSessions { if now-session.CreatedAt > ttlMs { delete(svc.cascadeSessions, id) deletedCount++ } } if deletedCount > 0 { svc.logger.Info("expired sessions cleaned up", "deleted_count", deletedCount, "remaining_sessions", len(svc.cascadeSessions), ) } } // Stop 优雅关闭服务 func (svc *LanguageServerService) Stop() { select { case svc.stopCleanup <- struct{}{}: default: } } // SetSessionTTL sets the session TTL for testing purposes func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) { svc.sessionTTLSeconds = ttlSeconds } // GetCascadeSessions returns the current cascade sessions map (for testing) func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession { svc.sessionMutex.RLock() defer svc.sessionMutex.RUnlock() return svc.cascadeSessions } // ============================================================================ // Cascade 业务逻辑 // ============================================================================ // StartCascade 启动新的 Cascade Agent 会话 func (svc *LanguageServerService) StartCascade( ctx context.Context, model string, systemPrompt string, metadata map[string]string, token string, ) (string, error) { // 1. 验证输入 if model == "" { return "", fmt.Errorf("model is required") } if token == "" { return "", fmt.Errorf("oauth token is required") } // 2. 生成会话 ID sessionID := uuid.New().String() // 3. 创建会话 session := &CascadeSession{ ID: sessionID, ModelName: model, Messages: make([]map[string]interface{}, 0), Metadata: metadata, Token: token, CreatedAt: getCurrentTimeMS(), } // 如果提供了系统提示,添加为初始消息 if systemPrompt != "" { session.Messages = append(session.Messages, map[string]interface{}{ "role": "user", "content": systemPrompt, }) } // 4. 保存会话 svc.sessionMutex.Lock() svc.cascadeSessions[sessionID] = session svc.sessionMutex.Unlock() svc.logger.Info("cascade session started", "session_id", sessionID, "model", model, "has_system_prompt", systemPrompt != "") return sessionID, nil } // SendUserMessage 发送用户消息到 Cascade // 返回流式更新通道 func (svc *LanguageServerService) SendUserMessage( ctx context.Context, cascadeID string, userMessage string, token string, ) (<-chan interface{}, error) { // 改进 1: 获取速率限制令牌 select { case svc.rateLimiter <- struct{}{}: // 获得令牌,继续 case <-ctx.Done(): return nil, fmt.Errorf("context cancelled") default: // 没有令牌,需要等待 select { case svc.rateLimiter <- struct{}{}: // 获得令牌 case <-ctx.Done(): return nil, fmt.Errorf("context cancelled while waiting for rate limit") case <-time.After(30 * time.Second): return nil, fmt.Errorf("rate limit timeout: too many concurrent messages") } } // 1. 获取会话 svc.sessionMutex.RLock() session, exists := svc.cascadeSessions[cascadeID] svc.sessionMutex.RUnlock() if !exists { // 释放令牌 <-svc.rateLimiter return nil, fmt.Errorf("cascade session not found: %s", cascadeID) } // 2. 验证 token if token != session.Token { // 释放令牌 <-svc.rateLimiter return nil, fmt.Errorf("invalid token for session") } // 改进 2: 并发安全的消息追加(深拷贝消息列表) svc.sessionMutex.Lock() newMessages := make([]map[string]interface{}, len(session.Messages)+1) copy(newMessages, session.Messages) newMessages[len(newMessages)-1] = map[string]interface{}{ "role": "user", "content": userMessage, } session.Messages = newMessages svc.sessionMutex.Unlock() // 4. 创建响应通道 updateChan := make(chan interface{}, 100) // 5. 启动后台 goroutine 处理 API 调用 go func() { defer func() { // 关闭通道 close(updateChan) // 改进 1: 释放速率限制令牌 <-svc.rateLimiter }() // 调用上游 API(关键!这里需要伪装) svc.callUpstreamAPI(ctx, session, updateChan) }() svc.logger.Info("user message sent to cascade", "session_id", cascadeID, "message_length", len(userMessage), "concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数 ) return updateChan, nil } // CancelCascade 取消 Cascade 会话 func (svc *LanguageServerService) CancelCascade( ctx context.Context, cascadeID string, ) error { svc.sessionMutex.Lock() _, exists := svc.cascadeSessions[cascadeID] svc.sessionMutex.Unlock() if !exists { return fmt.Errorf("cascade session not found: %s", cascadeID) } // TODO: 取消正在进行的 API 调用 svc.logger.Info("cascade cancelled", "session_id", cascadeID) return nil } // ============================================================================ // 模型配置 // ============================================================================ // ModelConfig 模型配置 type ModelConfig struct { Name string DisplayName string MaxTokens int SupportsThinking bool ThinkingBudget int SupportsImages bool Provider string } // GetAvailableModels 获取可用模型列表 func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) { models := []ModelConfig{ { Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsThinking: false, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsThinking: false, SupportsImages: true, Provider: "anthropic", }, { Name: "gemini-3-pro", DisplayName: "Gemini 3 Pro", MaxTokens: 128000, SupportsThinking: false, SupportsImages: true, Provider: "google", }, } return models, nil } // ============================================================================ // 状态查询 // ============================================================================ // GetStatus 获取服务状态 func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) { // TODO: 检查上游 API 连接状态 return "running", nil } // ============================================================================ // 内部方法 // ============================================================================ // callUpstreamAPI 调用上游 Anthropic API // 这是关键方法:需要注入所有伪装信息 // // 伪装层包括: // 1. User-Agent(来自 metadata 或动态生成) // 2. 设备指纹(machine_id, mac_machine_id, dev_device_id, sqm_id) // 3. TLS 指纹(通过 http.Transport 处理) // 4. OAuth token 自动刷新 // 5. 请求头完整性 func (svc *LanguageServerService) callUpstreamAPI( ctx context.Context, session *CascadeSession, updateChan chan<- interface{}, ) { // 检查上游配置 if svc.upstreamBaseURL == "" || svc.upstreamAPIKey == "" { svc.logger.Error("upstream api configuration missing", "has_base_url", svc.upstreamBaseURL != "", "has_api_key", svc.upstreamAPIKey != "", ) updateChan <- map[string]interface{}{ "type": "error", "error": "upstream api not configured", } return } // 1. 准备请求体 requestBody := map[string]interface{}{ "model": session.ModelName, "messages": session.Messages, "stream": true, } bodyJSON, err := json.Marshal(requestBody) if err != nil { svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err, ) updateChan <- map[string]interface{}{ "type": "error", "error": "failed to prepare request", } return } // 2. 构建上游请求 URL upstreamURL := svc.upstreamBaseURL + "/v1/messages" // 3. 创建 HTTP 请求 req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(bodyJSON)) if err != nil { svc.logger.Error("failed to create request", "session_id", session.ID, "error", err, ) updateChan <- map[string]interface{}{ "type": "error", "error": "failed to create request", } return } // 4. 设置基础请求头 req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+session.Token) req.Header.Set("x-api-key", session.Token) // Claude API 兼容 // 5. 应用伪装信息 if userAgent := session.Metadata["user-agent"]; userAgent != "" { req.Header.Set("User-Agent", userAgent) } // 提取其他伪装 headers(如果在 metadata 中) if customHeaders := session.Metadata["custom-headers"]; customHeaders != "" { // 可以在这里解析并应用自定义 headers } svc.logger.Debug("sending upstream request", "session_id", session.ID, "url", upstreamURL, "model", session.ModelName, ) // 6. 发送请求 resp, err := svc.httpUpstream.Do(req, "", 0, 10) if err != nil { svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err, ) updateChan <- map[string]interface{}{ "type": "error", "error": fmt.Sprintf("upstream request failed: %v", err), } return } defer func() { _ = resp.Body.Close() }() // 7. 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", resp.StatusCode, "body", string(respBody), ) updateChan <- map[string]interface{}{ "type": "error", "status_code": resp.StatusCode, "error": string(respBody), } return } // 8. 处理流式响应 svc.streamUpstreamResponse(ctx, session.ID, resp, updateChan) } // streamUpstreamResponse 处理上游 SSE 流式响应 func (svc *LanguageServerService) streamUpstreamResponse( ctx context.Context, sessionID string, resp *http.Response, updateChan chan<- interface{}, ) { scanner := bufio.NewScanner(resp.Body) // 设置合理的缓冲区大小 scanner.Buffer(make([]byte, 64*1024), 512*1024) for scanner.Scan() { select { case <-ctx.Done(): svc.logger.Info("streaming cancelled", "session_id", sessionID) return default: } line := strings.TrimSpace(scanner.Text()) // 跳过空行 if line == "" { continue } // 跳过注释行 if strings.HasPrefix(line, ":") { continue } // 解析 SSE 格式 (data: {...}) if !strings.HasPrefix(line, "data:") { continue } eventData := strings.TrimPrefix(line, "data:") eventData = strings.TrimSpace(eventData) // 解析 JSON var event map[string]interface{} if err := json.Unmarshal([]byte(eventData), &event); err != nil { svc.logger.Debug("failed to parse event", "session_id", sessionID, "error", err, "data", eventData, ) continue } // 发送事件到客户端通道 select { case updateChan <- event: case <-ctx.Done(): return case <-time.After(5 * time.Second): svc.logger.Warn("channel send timeout", "session_id", sessionID, ) return } } if err := scanner.Err(); err != nil { svc.logger.Error("scanning upstream response failed", "session_id", sessionID, "error", err, ) } } // getCurrentTimeMS 获取当前时间戳(毫秒) func getCurrentTimeMS() int64 { return time.Now().UnixMilli() }