sub2api/backend/internal/service/language_server_service.go
win 12ae97b755 fix: Increase maxOutputTokens in Antigravity test request from 1 to 10
The test request was using maxOutputTokens: 1, which caused Google API to
generate only 1 token. When decoded, this single token produced "It" as the
response, making it look like an error.

Changed:
- Content: "." → "Test connection" (more meaningful prompt)
- MaxTokens: 1 → 10 (enough tokens to verify connection is working)

This fixes the issue where account test always showed "It" in the response,
which was actually just the truncated output from the single-token generation.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
2026-04-11 18:49:53 +08:00

560 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// CascadeSession holds the in-memory state of one Cascade Agent conversation.
type CascadeSession struct {
	// ID is the session identifier.
	ID string

	// ModelName is the model name placed in the upstream request body.
	ModelName string

	// Messages is the conversation history. Each entry is a map carrying
	// "role" and "content" keys.
	Messages []map[string]interface{}

	// Metadata carries per-session client attributes such as the device
	// fingerprint and the User-Agent.
	Metadata map[string]string

	// Token is the OAuth token bound to this session.
	Token string

	// CreatedAt is the creation time in Unix milliseconds.
	CreatedAt int64
}
// LanguageServerService is the business-logic layer for the Cascade Agent
// flow: it keeps sessions in memory and forwards their messages to the
// upstream API.
type LanguageServerService struct {
	// cascadeSessions stores the live sessions keyed by session ID.
	// sessionMutex guards the map.
	cascadeSessions map[string]*CascadeSession
	sessionMutex    sync.RWMutex

	// httpUpstream performs the outbound HTTP requests.
	httpUpstream HTTPUpstream

	// Upstream endpoint configuration.
	upstreamBaseURL string
	upstreamAPIKey  string

	// logger receives the structured log output of the service.
	logger *slog.Logger

	// rateLimiter is a buffered channel used as a counting semaphore: a
	// message handler sends to take a slot and receives to give it back,
	// which bounds the number of messages processed concurrently.
	rateLimiter chan struct{}

	// sessionTTLSeconds is the session lifetime, in seconds, measured from
	// the session's creation time.
	sessionTTLSeconds int64

	// cleanupTicker drives the periodic expired-session sweep; stopCleanup
	// tells the sweep goroutine to exit.
	cleanupTicker *time.Ticker
	stopCleanup   chan struct{}
}
// NewLanguageServerService builds a LanguageServerService and starts its
// background session-cleanup goroutine.
//
// The upstream base URL and API key are read from the ANTHROPIC_BASE_URL and
// ANTHROPIC_API_KEY environment variables; a trailing "/" on the base URL is
// removed. The service allows 100 concurrently processed messages and expires
// sessions after 3600 seconds.
//
// A nil logger is replaced with slog.Default() so that later log calls cannot
// dereference a nil pointer.
func NewLanguageServerService(
	logger *slog.Logger,
	httpUpstream HTTPUpstream,
) *LanguageServerService {
	if logger == nil {
		logger = slog.Default()
	}

	svc := &LanguageServerService{
		cascadeSessions:   make(map[string]*CascadeSession),
		logger:            logger,
		httpUpstream:      httpUpstream,
		upstreamBaseURL:   strings.TrimSuffix(os.Getenv("ANTHROPIC_BASE_URL"), "/"),
		upstreamAPIKey:    os.Getenv("ANTHROPIC_API_KEY"),
		rateLimiter:       make(chan struct{}, 100), // at most 100 messages in flight
		sessionTTLSeconds: 3600,                     // sessions expire after one hour
		stopCleanup:       make(chan struct{}),
	}

	// Start the periodic sweep that removes expired sessions.
	svc.startSessionCleanup()
	return svc
}
// startSessionCleanup launches a goroutine that runs cleanupExpiredSessions
// once per minute. The goroutine stops the ticker and exits when it receives
// from stopCleanup.
func (svc *LanguageServerService) startSessionCleanup() {
	ticker := time.NewTicker(time.Minute)
	svc.cleanupTicker = ticker

	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-svc.stopCleanup:
				return
			case <-ticker.C:
				svc.cleanupExpiredSessions()
			}
		}
	}()
}
// cleanupExpiredSessions removes every session whose age, measured from
// CreatedAt, exceeds the configured TTL. It logs a summary only when at least
// one session was removed.
//
// The TTL is read while holding sessionMutex because SetSessionTTL writes it
// under the same lock; reading it before locking would race with that write.
func (svc *LanguageServerService) cleanupExpiredSessions() {
	now := getCurrentTimeMS()

	svc.sessionMutex.Lock()
	defer svc.sessionMutex.Unlock()

	ttlMs := svc.sessionTTLSeconds * 1000
	removed := 0
	for id, session := range svc.cascadeSessions {
		if now-session.CreatedAt > ttlMs {
			delete(svc.cascadeSessions, id)
			removed++
		}
	}
	if removed == 0 {
		return
	}

	svc.logger.Info("expired sessions cleaned up",
		"deleted_count", removed,
		"remaining_sessions", len(svc.cascadeSessions),
	)
}

// Stop shuts down the background cleanup goroutine. It is safe to call more
// than once and from several goroutines.
//
// The stop channel is closed rather than sent to: a non-blocking send on an
// unbuffered channel is silently dropped whenever the cleanup goroutine is
// busy with a sweep instead of waiting in its select, which would leave the
// goroutine and its ticker running forever. A closed channel stays readable,
// so the goroutine sees the signal the next time it reaches the select.
func (svc *LanguageServerService) Stop() {
	// sessionMutex serializes concurrent Stop calls so the channel is
	// closed at most once.
	svc.sessionMutex.Lock()
	defer svc.sessionMutex.Unlock()

	if svc.stopCleanup == nil {
		return
	}
	select {
	case <-svc.stopCleanup:
		// Already closed by an earlier Stop call.
	default:
		close(svc.stopCleanup)
	}
}

// SetSessionTTL sets the session TTL, in seconds. It exists for tests.
//
// The write happens under sessionMutex so it is ordered with the read in
// cleanupExpiredSessions, which runs on the background goroutine.
func (svc *LanguageServerService) SetSessionTTL(ttlSeconds int64) {
	svc.sessionMutex.Lock()
	svc.sessionTTLSeconds = ttlSeconds
	svc.sessionMutex.Unlock()
}
// GetCascadeSessions returns a snapshot of the current sessions, keyed by
// session ID. It exists for tests.
//
// The returned map is a shallow copy taken under the read lock: handing out
// the internal map would let the caller iterate it after the lock is released
// while other goroutines insert or delete entries. The *CascadeSession values
// are shared with the service, not copied.
func (svc *LanguageServerService) GetCascadeSessions() map[string]*CascadeSession {
	svc.sessionMutex.RLock()
	defer svc.sessionMutex.RUnlock()

	snapshot := make(map[string]*CascadeSession, len(svc.cascadeSessions))
	for id, session := range svc.cascadeSessions {
		snapshot[id] = session
	}
	return snapshot
}
// ============================================================================
// Cascade 业务逻辑
// ============================================================================
// StartCascade creates a new Cascade Agent session and returns its ID.
//
// model and token must be non-empty, otherwise an error is returned and no
// session is created. A non-empty systemPrompt becomes the first message of
// the conversation, stored with role "user". metadata is stored on the
// session as given. ctx is currently unused.
func (svc *LanguageServerService) StartCascade(
	ctx context.Context,
	model string,
	systemPrompt string,
	metadata map[string]string,
	token string,
) (string, error) {
	// Reject incomplete input before allocating anything.
	switch {
	case model == "":
		return "", fmt.Errorf("model is required")
	case token == "":
		return "", fmt.Errorf("oauth token is required")
	}

	// Seed the history with the system prompt when one was supplied.
	hasSystemPrompt := systemPrompt != ""
	history := []map[string]interface{}{}
	if hasSystemPrompt {
		history = append(history, map[string]interface{}{
			"role":    "user",
			"content": systemPrompt,
		})
	}

	id := uuid.New().String()
	created := &CascadeSession{
		ID:        id,
		ModelName: model,
		Messages:  history,
		Metadata:  metadata,
		Token:     token,
		CreatedAt: getCurrentTimeMS(),
	}

	// Publish the session.
	svc.sessionMutex.Lock()
	svc.cascadeSessions[id] = created
	svc.sessionMutex.Unlock()

	svc.logger.Info("cascade session started",
		"session_id", id,
		"model", model,
		"has_system_prompt", hasSystemPrompt)
	return id, nil
}
// SendUserMessage appends userMessage to the session identified by cascadeID
// and starts an upstream call for it. It returns a channel on which the
// streamed updates (or a single error event) are delivered; the channel is
// closed when the upstream call finishes.
//
// A rate-limiter slot is taken before any work is done. If none is free the
// call waits up to 30 seconds, or until ctx is done, and then fails. The slot
// is returned on every error path and when the background call ends.
//
// Errors are returned when ctx is done, when no slot becomes available in
// time, when the session does not exist, or when token does not match the
// token the session was created with.
func (svc *LanguageServerService) SendUserMessage(
	ctx context.Context,
	cascadeID string,
	userMessage string,
	token string,
) (<-chan interface{}, error) {
	// release gives back the rate-limiter slot acquired below.
	release := func() { <-svc.rateLimiter }

	// Fast path: take a slot without waiting when one is free.
	select {
	case svc.rateLimiter <- struct{}{}:
	case <-ctx.Done():
		return nil, fmt.Errorf("context cancelled")
	default:
		// All slots are taken; wait for one, bounded by ctx and a timeout.
		select {
		case svc.rateLimiter <- struct{}{}:
		case <-ctx.Done():
			return nil, fmt.Errorf("context cancelled while waiting for rate limit")
		case <-time.After(30 * time.Second):
			return nil, fmt.Errorf("rate limit timeout: too many concurrent messages")
		}
	}

	// Look up, authorize and update the session in one critical section so
	// the session cannot change between the check and the append.
	svc.sessionMutex.Lock()
	session, exists := svc.cascadeSessions[cascadeID]
	if !exists {
		svc.sessionMutex.Unlock()
		release()
		return nil, fmt.Errorf("cascade session not found: %s", cascadeID)
	}
	if token != session.Token {
		svc.sessionMutex.Unlock()
		release()
		return nil, fmt.Errorf("invalid token for session")
	}

	// Copy-on-write append: build a fresh slice so a history slice already
	// handed to an in-flight upstream call is never modified.
	history := make([]map[string]interface{}, len(session.Messages), len(session.Messages)+1)
	copy(history, session.Messages)
	history = append(history, map[string]interface{}{
		"role":    "user",
		"content": userMessage,
	})
	session.Messages = history

	// Give the background call its own copy of the session. Reading the
	// shared session's Messages field outside the lock would race with later
	// SendUserMessage calls on the same session.
	requestView := *session
	svc.sessionMutex.Unlock()

	updateChan := make(chan interface{}, 100)

	go func() {
		defer func() {
			close(updateChan)
			release()
		}()
		svc.callUpstreamAPI(ctx, &requestView, updateChan)
	}()

	svc.logger.Info("user message sent to cascade",
		"session_id", cascadeID,
		"message_length", len(userMessage),
		// Slots currently held, i.e. the number of messages in flight.
		"concurrent_requests", len(svc.rateLimiter),
	)
	return updateChan, nil
}
// CancelCascade handles a cancellation request for the session identified by
// cascadeID. It returns an error when the session does not exist.
//
// At present it only verifies that the session exists and logs the request:
// it neither aborts an in-flight upstream call nor removes the session. ctx
// is currently unused.
func (svc *LanguageServerService) CancelCascade(
	ctx context.Context,
	cascadeID string,
) error {
	// The lookup does not modify the map, so the shared lock is sufficient.
	svc.sessionMutex.RLock()
	_, exists := svc.cascadeSessions[cascadeID]
	svc.sessionMutex.RUnlock()

	if !exists {
		return fmt.Errorf("cascade session not found: %s", cascadeID)
	}

	// TODO: cancel the in-flight upstream API call.
	svc.logger.Info("cascade cancelled", "session_id", cascadeID)
	return nil
}
// ============================================================================
// 模型配置
// ============================================================================
// ModelConfig describes one model offered by the service.
type ModelConfig struct {
	// Name is the model identifier.
	Name string
	// DisplayName is the human-readable model name.
	DisplayName string
	// MaxTokens is the token limit advertised for the model.
	MaxTokens int
	// SupportsThinking reports whether the model is listed with thinking
	// support.
	SupportsThinking bool
	// ThinkingBudget is the thinking token budget; zero when unset.
	ThinkingBudget int
	// SupportsImages reports whether the model is listed with image support.
	SupportsImages bool
	// Provider names the model vendor.
	Provider string
}
// GetAvailableModels returns the fixed list of models the service advertises.
// The list is built in code; ctx is unused and the error is always nil.
func (svc *LanguageServerService) GetAvailableModels(ctx context.Context) ([]ModelConfig, error) {
	// anthropic fills in the settings shared by every Anthropic entry.
	anthropic := func(name, displayName string, thinking bool, budget int) ModelConfig {
		return ModelConfig{
			Name:             name,
			DisplayName:      displayName,
			MaxTokens:        200000,
			SupportsThinking: thinking,
			ThinkingBudget:   budget,
			SupportsImages:   true,
			Provider:         "anthropic",
		}
	}

	return []ModelConfig{
		anthropic("claude-opus-4-6", "Claude Opus 4.6", true, 32000),
		anthropic("claude-sonnet-4-6", "Claude Sonnet 4.6", false, 0),
		anthropic("claude-haiku-4-5", "Claude Haiku 4.5", false, 0),
		{
			Name:             "gemini-3-pro",
			DisplayName:      "Gemini 3 Pro",
			MaxTokens:        128000,
			SupportsThinking: false,
			SupportsImages:   true,
			Provider:         "google",
		},
	}, nil
}
// ============================================================================
// 状态查询
// ============================================================================
// GetStatus reports the service status. It currently always returns
// "running" with a nil error; ctx is unused.
func (svc *LanguageServerService) GetStatus(ctx context.Context) (string, error) {
	// TODO: check connectivity to the upstream API.
	const status = "running"
	return status, nil
}
// ============================================================================
// 内部方法
// ============================================================================
// callUpstreamAPI sends the session's conversation to the upstream
// "/v1/messages" endpoint as a streaming request and forwards the outcome to
// updateChan.
//
// Request shaping performed here:
//   - body: {"model", "messages", "stream": true}
//   - Content-Type: application/json
//   - Authorization ("Bearer <token>") and x-api-key, both carrying the
//     session token
//   - User-Agent: copied from session.Metadata["user-agent"] when non-empty
//
// Any failure before streaming starts is logged and reported as a single
// {"type": "error", ...} event on updateChan. On success the response is
// handed to streamUpstreamResponse. This function never closes updateChan.
//
// NOTE(review): upstreamAPIKey must be non-empty for the call to proceed but
// is not sent with the request; the session token is used for both
// credential headers. Confirm this is intended.
func (svc *LanguageServerService) callUpstreamAPI(
	ctx context.Context,
	session *CascadeSession,
	updateChan chan<- interface{},
) {
	// Refuse to call out when the upstream configuration is incomplete.
	if svc.upstreamBaseURL == "" || svc.upstreamAPIKey == "" {
		svc.logger.Error("upstream api configuration missing",
			"has_base_url", svc.upstreamBaseURL != "",
			"has_api_key", svc.upstreamAPIKey != "",
		)
		updateChan <- map[string]interface{}{
			"type":  "error",
			"error": "upstream api not configured",
		}
		return
	}

	// SendUserMessage replaces session.Messages while holding sessionMutex,
	// so the slice must be read under the lock as well. The slice is swapped
	// for a fresh copy on every append, so the snapshot taken here is never
	// modified afterwards.
	svc.sessionMutex.RLock()
	messages := session.Messages
	svc.sessionMutex.RUnlock()

	// 1. Build the request body.
	requestBody := map[string]interface{}{
		"model":    session.ModelName,
		"messages": messages,
		"stream":   true,
	}
	bodyJSON, err := json.Marshal(requestBody)
	if err != nil {
		svc.logger.Error("failed to marshal request",
			"session_id", session.ID,
			"error", err,
		)
		updateChan <- map[string]interface{}{
			"type":  "error",
			"error": "failed to prepare request",
		}
		return
	}

	// 2. Build the upstream URL.
	upstreamURL := svc.upstreamBaseURL + "/v1/messages"

	// 3. Create the HTTP request, bound to ctx so cancellation aborts it.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(bodyJSON))
	if err != nil {
		svc.logger.Error("failed to create request",
			"session_id", session.ID,
			"error", err,
		)
		updateChan <- map[string]interface{}{
			"type":  "error",
			"error": "failed to create request",
		}
		return
	}

	// 4. Base headers. The token is sent in both credential headers.
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+session.Token)
	req.Header.Set("x-api-key", session.Token)

	// 5. Client identity taken from the session metadata.
	if userAgent := session.Metadata["user-agent"]; userAgent != "" {
		req.Header.Set("User-Agent", userAgent)
	}
	// TODO: parse session.Metadata["custom-headers"] and apply the extra
	// headers it describes. Nothing is done with that key yet.

	svc.logger.Debug("sending upstream request",
		"session_id", session.ID,
		"url", upstreamURL,
		"model", session.ModelName,
	)

	// 6. Send the request. The trailing arguments are passed through as-is;
	// their meaning is defined by HTTPUpstream.
	resp, err := svc.httpUpstream.Do(req, "", 0, 10)
	if err != nil {
		svc.logger.Error("upstream request failed",
			"session_id", session.ID,
			"error", err,
		)
		updateChan <- map[string]interface{}{
			"type":  "error",
			"error": fmt.Sprintf("upstream request failed: %v", err),
		}
		return
	}
	// The close error is deliberately ignored: the body has been consumed or
	// abandoned by then and there is nothing useful to do with it.
	defer func() { _ = resp.Body.Close() }()

	// 7. Relay upstream error responses, reading at most 2 MiB of the body.
	if resp.StatusCode >= 400 {
		// A read error leaves respBody partial or empty; the status code is
		// still reported.
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		svc.logger.Error("upstream error response",
			"session_id", session.ID,
			"status_code", resp.StatusCode,
			"body", string(respBody),
		)
		updateChan <- map[string]interface{}{
			"type":        "error",
			"status_code": resp.StatusCode,
			"error":       string(respBody),
		}
		return
	}

	// 8. Stream the successful response to the caller.
	svc.streamUpstreamResponse(ctx, session.ID, resp, updateChan)
}
// streamUpstreamResponse reads the upstream SSE body line by line and
// forwards every "data:" payload that parses as a JSON object to updateChan.
//
// Lines that do not start with "data:" (blank lines, ":" comments, other SSE
// fields) are skipped, and payloads that are not valid JSON are logged at
// debug level and skipped. Streaming stops when the body ends, when ctx is
// done, or when a single send to updateChan has been blocked for 5 seconds.
// A scanner error, such as a line longer than the 512 KiB limit, is logged
// after the loop.
func (svc *LanguageServerService) streamUpstreamResponse(
	ctx context.Context,
	sessionID string,
	resp *http.Response,
	updateChan chan<- interface{},
) {
	const (
		initialLineBuf = 64 * 1024
		maxLineLen     = 512 * 1024
		sendTimeout    = 5 * time.Second
	)

	lines := bufio.NewScanner(resp.Body)
	lines.Buffer(make([]byte, initialLineBuf), maxLineLen)

	for lines.Scan() {
		// Stop as soon as the caller has gone away.
		if ctx.Err() != nil {
			svc.logger.Info("streaming cancelled", "session_id", sessionID)
			return
		}

		// Only "data:" lines carry events; everything else is ignored.
		payload, isData := strings.CutPrefix(strings.TrimSpace(lines.Text()), "data:")
		if !isData {
			continue
		}
		payload = strings.TrimSpace(payload)

		var event map[string]interface{}
		if err := json.Unmarshal([]byte(payload), &event); err != nil {
			svc.logger.Debug("failed to parse event",
				"session_id", sessionID,
				"error", err,
				"data", payload,
			)
			continue
		}

		// Hand the event to the consumer, giving up if it stops reading.
		select {
		case updateChan <- event:
		case <-ctx.Done():
			return
		case <-time.After(sendTimeout):
			svc.logger.Warn("channel send timeout",
				"session_id", sessionID,
			)
			return
		}
	}

	if err := lines.Err(); err != nil {
		svc.logger.Error("scanning upstream response failed",
			"session_id", sessionID,
			"error", err,
		)
	}
}
// getCurrentTimeMS returns the current wall-clock time as Unix milliseconds.
func getCurrentTimeMS() int64 {
	now := time.Now()
	return now.UnixMilli()
}