package service import ( "bufio" "context" "encoding/json" "fmt" "io" "log/slog" "strings" "sync" "time" "github.com/google/uuid" ) // CascadeSession 代表一个 Cascade Agent 会话 type CascadeSession struct { ID string ModelName string Messages []map[string]interface{} // {role, content} Metadata map[string]string // 设备指纹、User-Agent 等 Token string // OAuth token CreatedAt int64 } // LanguageServerService 业务逻辑层 // 处理 Cascade Agent 流程,通过 AntigravityGatewayService 转发到上游 API type LanguageServerService struct { // 会话管理 cascadeSessions map[string]*CascadeSession sessionMutex sync.RWMutex // 上游 HTTP 服务(用于发送请求) httpUpstream HTTPUpstream // Antigravity 网关(账号池调度 + TLS 指纹 + token 刷新) antigravitySvc *AntigravityGatewayService accountRepo AccountRepository // 日志 logger *slog.Logger // 改进 1: 速率限制 (令牌桶) // 限制并发消息处理数量,保护上游 API rateLimiter chan struct{} // 改进 3: 会话过期时间 (秒) sessionTTLSeconds int64 // 改进 3: 定期清理后台任务 cleanupTicker *time.Ticker stopCleanup chan struct{} } func NewLanguageServerService( logger *slog.Logger, httpUpstream HTTPUpstream, antigravitySvc *AntigravityGatewayService, accountRepo AccountRepository, ) *LanguageServerService { svc := &LanguageServerService{ cascadeSessions: make(map[string]*CascadeSession), logger: logger, httpUpstream: httpUpstream, antigravitySvc: antigravitySvc, accountRepo: accountRepo, rateLimiter: make(chan struct{}, 100), // 改进 1: 限制 100 个并发消息 sessionTTLSeconds: 3600, // 改进 3: 会话默认 1 小时过期 stopCleanup: make(chan struct{}), } // 改进 3: 启动后台清理任务 svc.startSessionCleanup() return svc } // startSessionCleanup 启动会话定期清理任务 func (svc *LanguageServerService) startSessionCleanup() { svc.cleanupTicker = time.NewTicker(1 * time.Minute) go func() { for { select { case <-svc.cleanupTicker.C: svc.cleanupExpiredSessions() case <-svc.stopCleanup: svc.cleanupTicker.Stop() return } } }() } // cleanupExpiredSessions 清理过期的会话 func (svc *LanguageServerService) cleanupExpiredSessions() { now := getCurrentTimeMS() ttlMs := svc.sessionTTLSeconds * 1000 svc.sessionMutex.Lock() defer svc.sessionMutex.Unlock() deletedCount := 0 for id, session := range svc.cascadeSessions { if now-session.CreatedAt > ttlMs { delete(svc.cascadeSessions, id) deletedCount++ } } if deletedCount > 0 { svc.logger.Info("expired sessions cleaned up", "deleted_count", deletedCount, "remaining_sessions", len(svc.cascadeSessions), ) } } // Stop 优雅关闭服务 func (svc *LanguageServerService) Stop() { select { case svc.stopCleanup <- struct{}{}: default: } } // SetSessionTTL sets the session TTL for testing purposes func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) { svc.sessionTTLSeconds = ttlSeconds } // GetCascadeSessions returns the current cascade sessions map (for testing) func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession { svc.sessionMutex.RLock() defer svc.sessionMutex.RUnlock() return svc.cascadeSessions } // ============================================================================ // Cascade 业务逻辑 // ============================================================================ // StartCascade 启动新的 Cascade Agent 会话 func (svc *LanguageServerService) StartCascade( ctx context.Context, model string, systemPrompt string, metadata map[string]string, token string, ) (string, error) { // 1. 验证输入 if model == "" { return "", fmt.Errorf("model is required") } if token == "" { return "", fmt.Errorf("oauth token is required") } // 2. 生成会话 ID sessionID := uuid.New().String() // 3. 创建会话 session := &CascadeSession{ ID: sessionID, ModelName: model, Messages: make([]map[string]interface{}, 0), Metadata: metadata, Token: token, CreatedAt: getCurrentTimeMS(), } // 如果提供了系统提示,添加为初始消息 if systemPrompt != "" { session.Messages = append(session.Messages, map[string]interface{}{ "role": "user", "content": systemPrompt, }) } // 4. 保存会话 svc.sessionMutex.Lock() svc.cascadeSessions[sessionID] = session svc.sessionMutex.Unlock() svc.logger.Info("cascade session started", "session_id", sessionID, "model", model, "has_system_prompt", systemPrompt != "") return sessionID, nil } // SendUserMessage 发送用户消息到 Cascade // 返回流式更新通道 func (svc *LanguageServerService) SendUserMessage( ctx context.Context, cascadeID string, userMessage string, token string, ) (<-chan interface{}, error) { // 改进 1: 获取速率限制令牌 select { case svc.rateLimiter <- struct{}{}: // 获得令牌,继续 case <-ctx.Done(): return nil, fmt.Errorf("context cancelled") default: // 没有令牌,需要等待 select { case svc.rateLimiter <- struct{}{}: // 获得令牌 case <-ctx.Done(): return nil, fmt.Errorf("context cancelled while waiting for rate limit") case <-time.After(30 * time.Second): return nil, fmt.Errorf("rate limit timeout: too many concurrent messages") } } // 1. 获取会话 svc.sessionMutex.RLock() session, exists := svc.cascadeSessions[cascadeID] svc.sessionMutex.RUnlock() if !exists { // 释放令牌 <-svc.rateLimiter return nil, fmt.Errorf("cascade session not found: %s", cascadeID) } // 2. 验证 token if token != session.Token { // 释放令牌 <-svc.rateLimiter return nil, fmt.Errorf("invalid token for session") } // 改进 2: 并发安全的消息追加(深拷贝消息列表) svc.sessionMutex.Lock() newMessages := make([]map[string]interface{}, len(session.Messages)+1) copy(newMessages, session.Messages) newMessages[len(newMessages)-1] = map[string]interface{}{ "role": "user", "content": userMessage, } session.Messages = newMessages svc.sessionMutex.Unlock() // 4. 创建响应通道 updateChan := make(chan interface{}, 100) // 5. 启动后台 goroutine 处理 API 调用 go func() { defer func() { // 关闭通道 close(updateChan) // 改进 1: 释放速率限制令牌 <-svc.rateLimiter }() // 调用上游 API(关键!这里需要伪装) svc.callUpstreamAPI(ctx, session, updateChan) }() svc.logger.Info("user message sent to cascade", "session_id", cascadeID, "message_length", len(userMessage), "concurrent_requests", 100-len(svc.rateLimiter), // 显示当前并发数 ) return updateChan, nil } // CancelCascade 取消 Cascade 会话 func (svc *LanguageServerService) CancelCascade( ctx context.Context, cascadeID string, ) error { svc.sessionMutex.Lock() _, exists := svc.cascadeSessions[cascadeID] svc.sessionMutex.Unlock() if !exists { return fmt.Errorf("cascade session not found: %s", cascadeID) } // TODO: 取消正在进行的 API 调用 svc.logger.Info("cascade cancelled", "session_id", cascadeID) return nil } // ============================================================================ // 模型配置 // ============================================================================ // ModelConfig 模型配置 type ModelConfig struct { Name string DisplayName string MaxTokens int SupportsThinking bool ThinkingBudget int SupportsImages bool Provider string } // GetAvailableModels 获取可用模型列表 func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) { models := []ModelConfig{ { Name: "claude-opus-4-7", DisplayName: "Claude Opus 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-sonnet-4-7", DisplayName: "Claude Sonnet 4.7", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 16000, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", MaxTokens: 200000, SupportsThinking: true, ThinkingBudget: 32000, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", MaxTokens: 200000, SupportsThinking: false, SupportsImages: true, Provider: "anthropic", }, { Name: "claude-haiku-4-5", DisplayName: "Claude Haiku 4.5", MaxTokens: 200000, SupportsThinking: false, SupportsImages: true, Provider: "anthropic", }, { Name: "gemini-3-pro", DisplayName: "Gemini 3 Pro", MaxTokens: 128000, SupportsThinking: false, SupportsImages: true, Provider: "google", }, } return models, nil } // ============================================================================ // 状态查询 // ============================================================================ // GetStatus 获取服务状态 func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) { // TODO: 检查上游 API 连接状态 return "running", nil } // ============================================================================ // 内部方法 // ============================================================================ // callUpstreamAPI 通过 AntigravityGatewayService 调用上游 API。 // 复用账号池调度、模型映射、TLS 指纹伪装、token 刷新和重试逻辑。 func (svc *LanguageServerService) callUpstreamAPI( ctx context.Context, session *CascadeSession, updateChan chan<- interface{}, ) { if svc.antigravitySvc == nil || svc.accountRepo == nil { updateChan <- map[string]interface{}{ "type": "error", "error": "antigravity gateway not configured", } return } // 1. 选取第一个可用的 Antigravity 账号 accounts, err := svc.accountRepo.ListByPlatform(ctx, PlatformAntigravity) if err != nil || len(accounts) == 0 { svc.logger.Error("no antigravity accounts available", "session_id", session.ID, "error", err) updateChan <- map[string]interface{}{ "type": "error", "error": "no antigravity accounts available", } return } account := &accounts[0] // 2. 准备 Claude 格式请求体 requestBody := map[string]interface{}{ "model": session.ModelName, "messages": session.Messages, "stream": true, "max_tokens": 8192, } bodyJSON, err := json.Marshal(requestBody) if err != nil { svc.logger.Error("failed to marshal request", "session_id", session.ID, "error", err) updateChan <- map[string]interface{}{ "type": "error", "error": "failed to prepare request", } return } svc.logger.Debug("forwarding via antigravity", "session_id", session.ID, "model", session.ModelName, "account_id", account.ID) // 3. 通过 AntigravityGatewayService 转发(完整 TLS 指纹 + token 刷新 + 重试) respBody, statusCode, err := svc.antigravitySvc.ForwardRaw(ctx, account, bodyJSON) if err != nil { svc.logger.Error("upstream request failed", "session_id", session.ID, "error", err) updateChan <- map[string]interface{}{ "type": "error", "error": fmt.Sprintf("upstream request failed: %v", err), } return } defer func() { _ = respBody.Close() }() // 4. 处理错误响应 if statusCode >= 400 { body, _ := io.ReadAll(io.LimitReader(respBody, 2<<20)) svc.logger.Error("upstream error response", "session_id", session.ID, "status_code", statusCode, "body", string(body)) updateChan <- map[string]interface{}{ "type": "error", "status_code": statusCode, "error": string(body), } return } // 5. 流式转发响应 svc.streamUpstreamResponse(ctx, session.ID, respBody, updateChan) } // streamUpstreamResponse 处理上游 SSE 流式响应 func (svc *LanguageServerService) streamUpstreamResponse( ctx context.Context, sessionID string, body io.ReadCloser, updateChan chan<- interface{}, ) { scanner := bufio.NewScanner(body) // 设置合理的缓冲区大小 scanner.Buffer(make([]byte, 64*1024), 512*1024) for scanner.Scan() { select { case <-ctx.Done(): svc.logger.Info("streaming cancelled", "session_id", sessionID) return default: } line := strings.TrimSpace(scanner.Text()) // 跳过空行 if line == "" { continue } // 跳过注释行 if strings.HasPrefix(line, ":") { continue } // 解析 SSE 格式 (data: {...}) if !strings.HasPrefix(line, "data:") { continue } eventData := strings.TrimPrefix(line, "data:") eventData = strings.TrimSpace(eventData) // 解析 JSON var event map[string]interface{} if err := json.Unmarshal([]byte(eventData), &event); err != nil { svc.logger.Debug("failed to parse event", "session_id", sessionID, "error", err, "data", eventData, ) continue } // 发送事件到客户端通道 select { case updateChan <- event: case <-ctx.Done(): return case <-time.After(5 * time.Second): svc.logger.Warn("channel send timeout", "session_id", sessionID, ) return } } if err := scanner.Err(); err != nil { svc.logger.Error("scanning upstream response failed", "session_id", sessionID, "error", err, ) } } // getCurrentTimeMS 获取当前时间戳(毫秒) func getCurrentTimeMS() int64 { return time.Now().UnixMilli() }