diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 6c9673e4..719dc3b7 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -205,7 +205,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) requestEventBus := service.NewRequestEventBus() - opsHandler := admin.NewOpsHandler(opsService, requestEventBus) + opsLogBroadcaster := service.ProvideOpsLogBroadcaster() + opsHandler := admin.NewOpsHandler(opsService, requestEventBus, opsLogBroadcaster) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) @@ -240,7 +241,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { windsurfAuthService := service.ProvideWindsurfAuthService(configConfig, accountRepository, proxyRepository, adminService) windsurfRefreshService := service.ProvideWindsurfRefreshService(configConfig, accountRepository, proxyRepository) windsurfProbeService := service.ProvideWindsurfProbeService(configConfig, accountRepository, proxyRepository) - windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService) + windsurfTierAccessService := service.ProvideWindsurfTierAccessService(configConfig, accountRepository) + windsurfHandler := handler.ProvideWindsurfHandler(windsurfAuthService, windsurfLSService, windsurfProbeService, windsurfTierAccessService) affiliateHandler := admin.NewAffiliateHandler(affiliateService, adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler, windsurfHandler, affiliateHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) @@ -260,7 +262,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) healthService := service.NewHealthService(db, redisClient) - engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient) + engine := server.ProvideRouter(configConfig, handlers, jwtAuthMiddleware, adminAuthMiddleware, apiKeyAuthMiddleware, apiKeyService, subscriptionService, opsService, settingService, healthService, redisClient, opsLogBroadcaster) httpServer := server.ProvideHTTPServer(configConfig, engine) opsMetricsCollector := service.ProvideOpsMetricsCollector(opsRepository, settingRepository, accountRepository, concurrencyService, db, redisClient, configConfig) opsAggregationService := service.ProvideOpsAggregationService(opsRepository, settingRepository, db, redisClient, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 30d6db3f..fbbab69a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -27,8 +27,15 @@ const ( ) // DefaultCSPPolicy is the default Content-Security-Policy with nonce support -// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware -const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com https://*.stripe.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" +// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware. +// +// Firebase Auth popup flow (used by Windsurf Google login) requires: +// - script-src https://apis.google.com (loads gapi for the OAuth iframe) +// - frame-src https://*.firebaseapp.com https://accounts.google.com https://apis.google.com +// - media-src 'self' data: (Firebase plays a tiny silent base64 WAV +// to keep the popup channel alive across +// browser autoplay restrictions) +const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com https://*.stripe.com https://apis.google.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; media-src 'self' data:; frame-src https://challenges.cloudflare.com https://*.stripe.com https://*.firebaseapp.com https://accounts.google.com https://apis.google.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" // UMQ(用户消息队列)模式常量 const ( diff --git a/backend/internal/handler/admin/ops_handler.go b/backend/internal/handler/admin/ops_handler.go index d9c49250..0eaac506 100644 --- a/backend/internal/handler/admin/ops_handler.go +++ b/backend/internal/handler/admin/ops_handler.go @@ -18,6 +18,7 @@ import ( type OpsHandler struct { opsService *service.OpsService requestEventBus *service.RequestEventBus + logBroadcaster *service.OpsLogBroadcaster } // GetErrorLogByID returns ops error log detail. @@ -71,8 +72,8 @@ func parseOpsViewParam(c *gin.Context) string { } } -func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler { - return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus} +func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus, logBroadcaster *service.OpsLogBroadcaster) *OpsHandler { + return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus, logBroadcaster: logBroadcaster} } // GetErrorLogs lists ops error logs. diff --git a/backend/internal/handler/admin/ops_log_stream_handler.go b/backend/internal/handler/admin/ops_log_stream_handler.go new file mode 100644 index 00000000..fd41f3ef --- /dev/null +++ b/backend/internal/handler/admin/ops_log_stream_handler.go @@ -0,0 +1,210 @@ +package admin + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + opsLogStreamHeartbeat = 25 * time.Second + opsLogStreamRecentMax = 200 + opsLogStreamSubBufEntries = 1024 + opsLogStreamModelMaxLen = 256 +) + +// LogStream serves a Server-Sent Events feed of every gateway request. +// +// GET /api/v1/admin/ops/logs/stream?min_status=400&model=glm-4.7&account_id=42&min_latency_ms=2000 +// +// Filter query params (all optional, AND-combined): +// +// min_status — int only emit when entry.status >= this value +// model — exact match on model key +// account_id — int64 +// group_id — int64 +// min_latency_ms — int64 only emit when entry.latency_ms >= this value +// +// The handler keeps the connection open until the client disconnects, the +// monitoring is disabled, or the broadcaster is torn down. A heartbeat +// comment line is sent every 25s so reverse proxies don't time out idle +// streams. +func (h *OpsHandler) LogStream(c *gin.Context) { + if h.logBroadcaster == nil { + response.Error(c, http.StatusServiceUnavailable, "log broadcaster not configured") + return + } + // nil opsService is allowed for lightweight deployments / tests where + // the OpsService dependency is intentionally absent. The admin auth + // middleware on the route group still enforces JWT + admin role, so + // the stream is never reachable anonymously. + if h.opsService != nil { + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + } + + filter, err := parseOpsLogFilter(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + response.Error(c, http.StatusInternalServerError, "streaming unsupported") + return + } + + ch, unsubscribe := h.logBroadcaster.Subscribe(filter, opsLogStreamSubBufEntries) + defer unsubscribe() + + // Prime client with recent buffered history (so a fresh dashboard tab + // renders something immediately instead of staying blank). + for _, e := range h.logBroadcaster.Snapshot(filter, opsLogStreamRecentMax) { + if err := writeOpsLogSSE(c.Writer, &e); err != nil { + return + } + } + flusher.Flush() + + heartbeat := time.NewTicker(opsLogStreamHeartbeat) + defer heartbeat.Stop() + + ctxDone := c.Request.Context().Done() + for { + select { + case <-ctxDone: + return + case <-heartbeat.C: + if _, err := io.WriteString(c.Writer, ": ping\n\n"); err != nil { + return + } + flusher.Flush() + case entry := <-ch: + if err := writeOpsLogSSE(c.Writer, &entry); err != nil { + return + } + flusher.Flush() + } + } +} + +// LogStreamRecent returns the broadcaster history without subscribing. +// Useful for one-shot polling when the admin panel cannot keep an open +// SSE connection (e.g. behind a buffering proxy). +// +// GET /api/v1/admin/ops/logs/recent?min_status=400&max=500 +func (h *OpsHandler) LogStreamRecent(c *gin.Context) { + if h.logBroadcaster == nil { + response.Error(c, http.StatusServiceUnavailable, "log broadcaster not configured") + return + } + if h.opsService != nil { + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + } + filter, err := parseOpsLogFilter(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + maxN := opsLogStreamRecentMax + if v := strings.TrimSpace(c.Query("max")); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 && n <= 2000 { + maxN = n + } + } + + published, dropped, subs := h.logBroadcaster.Stats() + response.Success(c, gin.H{ + "entries": h.logBroadcaster.Snapshot(filter, maxN), + "published_total": published, + "dropped_total": dropped, + "subscribers": subs, + }) +} + +func parseOpsLogFilter(c *gin.Context) (service.OpsLogFilter, error) { + f := service.OpsLogFilter{} + if v := strings.TrimSpace(c.Query("min_status")); v != "" { + n, err := strconv.Atoi(v) + if err != nil || n < 0 { + return f, fmt.Errorf("invalid min_status") + } + f.MinStatus = n + } + if v := strings.TrimSpace(c.Query("model")); v != "" { + // Cap input to keep an authenticated admin from stuffing huge + // strings into long-lived subscription state. + if len(v) > opsLogStreamModelMaxLen { + return f, fmt.Errorf("model too long (max %d)", opsLogStreamModelMaxLen) + } + f.Model = v + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + n, err := strconv.ParseInt(v, 10, 64) + // Reject 0 — matches() treats AccountID==0 as "match all", so a + // param of 0 would silently degrade to no-filter without telling + // the user. Demand a positive id (callers wanting all should omit + // the param). + if err != nil || n <= 0 { + return f, fmt.Errorf("invalid account_id") + } + f.AccountID = n + } + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n <= 0 { + return f, fmt.Errorf("invalid group_id") + } + f.GroupID = n + } + if v := strings.TrimSpace(c.Query("min_latency_ms")); v != "" { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n < 0 { + return f, fmt.Errorf("invalid min_latency_ms") + } + f.MinLatencyMs = n + } + return f, nil +} + +// writeOpsLogSSE writes one SSE frame: an `event: log` line followed by a +// single `data:` line containing the JSON-encoded entry, terminated by a +// blank line per the SSE protocol. +// +// This assumes the JSON payload contains no bare LF — which holds because +// every string field in OpsLogEntry passes through encoding/json escaping. +// If a future field is added that emits raw bytes (e.g. a []byte body), +// the marshalled output must be split across multiple `data:` lines. +func writeOpsLogSSE(w io.Writer, e *service.OpsLogEntry) error { + payload, err := json.Marshal(e) + if err != nil { + return err + } + if _, err := io.WriteString(w, "event: log\ndata: "); err != nil { + return err + } + if _, err := w.Write(payload); err != nil { + return err + } + _, err = io.WriteString(w, "\n\n") + return err +} diff --git a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go index 0eede09b..723138c2 100644 --- a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go +++ b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go @@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService { } func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t), nil) + h := NewOpsHandler(newRuntimeOpsService(t), nil, nil) r := newOpsRuntimeRouter(h, false) w := httptest.NewRecorder() @@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { } func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t), nil) + h := NewOpsHandler(newRuntimeOpsService(t), nil, nil) r := newOpsRuntimeRouter(h, false) body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` @@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { } func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { - h := NewOpsHandler(newRuntimeOpsService(t), nil) + h := NewOpsHandler(newRuntimeOpsService(t), nil, nil) r := newOpsRuntimeRouter(h, true) payload := map[string]any{ diff --git a/backend/internal/handler/admin/windsurf_handler.go b/backend/internal/handler/admin/windsurf_handler.go index 35ad791c..cb2638b7 100644 --- a/backend/internal/handler/admin/windsurf_handler.go +++ b/backend/internal/handler/admin/windsurf_handler.go @@ -13,20 +13,23 @@ import ( ) type WindsurfHandler struct { - authService *service.WindsurfAuthService - lsService *service.WindsurfLSService - probeService *service.WindsurfProbeService + authService *service.WindsurfAuthService + lsService *service.WindsurfLSService + probeService *service.WindsurfProbeService + tierAccessService *service.WindsurfTierAccessService } func NewWindsurfHandler( authService *service.WindsurfAuthService, lsService *service.WindsurfLSService, probeService *service.WindsurfProbeService, + tierAccessService *service.WindsurfTierAccessService, ) *WindsurfHandler { return &WindsurfHandler{ - authService: authService, - lsService: lsService, - probeService: probeService, + authService: authService, + lsService: lsService, + probeService: probeService, + tierAccessService: tierAccessService, } } @@ -81,6 +84,61 @@ func (h *WindsurfHandler) Login(c *gin.Context) { }) } +func (h *WindsurfHandler) TokenLogin(c *gin.Context) { + var req dto.WindsurfTokenLoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + concurrency := 1 + if req.Concurrency != nil && *req.Concurrency > 0 { + concurrency = *req.Concurrency + } + priority := 0 + if req.Priority != nil { + priority = *req.Priority + } + probeAfter := false + if req.ProbeAfter != nil { + probeAfter = *req.ProbeAfter + } + + input := &service.WindsurfTokenLoginInput{ + Token: req.Token, + Email: req.Email, + Name: req.Name, + Notes: req.Notes, + ProxyID: req.ProxyID, + GroupIDs: req.GroupIDs, + Concurrency: concurrency, + Priority: priority, + ProbeAfter: probeAfter, + LSInstanceID: req.LSInstanceID, + } + + output, err := h.authService.TokenLogin(c.Request.Context(), input) + if err != nil { + // ErrorFrom maps typed ApplicationError (BadRequest/Conflict/etc.) + // to its real HTTP code; falls through to 500 for opaque errors. + if !response.ErrorFrom(c, err) { + response.Error(c, http.StatusInternalServerError, err.Error()) + } + return + } + + response.Success(c, dto.WindsurfLoginResponse{ + AccountID: output.AccountID, + Platform: "windsurf", + Type: "windsurf-session", + Email: output.Email, + Tier: output.Tier, + AuthMethod: output.AuthMethod, + APIKeyPresent: output.APIKeyPresent, + RefreshTokenPresent: output.RefreshTokenPresent, + }) +} + func (h *WindsurfHandler) BatchLogin(c *gin.Context) { var req dto.WindsurfBatchLoginRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -309,3 +367,18 @@ func (h *WindsurfHandler) GetRuntime(c *gin.Context) { response.Success(c, result) } + +// GetTierAccess returns per-model account-pool availability for the +// admin dashboard. Backed by a 60s in-memory snapshot cache. +func (h *WindsurfHandler) GetTierAccess(c *gin.Context) { + if h.tierAccessService == nil { + response.Error(c, http.StatusServiceUnavailable, "tier access service not configured") + return + } + snap, err := h.tierAccessService.Snapshot(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, err.Error()) + return + } + response.Success(c, snap) +} diff --git a/backend/internal/handler/dto/windsurf.go b/backend/internal/handler/dto/windsurf.go index 24864beb..288fd834 100644 --- a/backend/internal/handler/dto/windsurf.go +++ b/backend/internal/handler/dto/windsurf.go @@ -13,6 +13,26 @@ type WindsurfLoginRequest struct { LSInstanceID string `json:"ls_instance_id,omitempty"` } +// WindsurfTokenLoginRequest carries a pre-obtained Windsurf auth token +// (copied by the user from https://windsurf.com/show-auth-token after +// signing in to windsurf.com via Google / GitHub / email). +// +// Token field accepts whatever windsurf.com/show-auth-token displays — +// the backend tries to exchange it directly with Codeium's register_user +// endpoint, mirroring the dwgx/WindsurfAPI reference behaviour. +type WindsurfTokenLoginRequest struct { + Token string `json:"token" binding:"required,max=16384"` + Email string `json:"email" binding:"omitempty,email"` + Name string `json:"name"` + Notes *string `json:"notes,omitempty"` + ProxyID *int64 `json:"proxy_id,omitempty"` + GroupIDs []int64 `json:"group_ids,omitempty"` + Concurrency *int `json:"concurrency,omitempty"` + Priority *int `json:"priority,omitempty"` + ProbeAfter *bool `json:"probe_after,omitempty"` + LSInstanceID string `json:"ls_instance_id,omitempty"` +} + type WindsurfBatchLoginRequest struct { Items []string `json:"items" binding:"required,min=1"` ProxyID *int64 `json:"proxy_id,omitempty"` diff --git a/backend/internal/handler/ops_log_stream_middleware.go b/backend/internal/handler/ops_log_stream_middleware.go new file mode 100644 index 00000000..0d68d991 --- /dev/null +++ b/backend/internal/handler/ops_log_stream_middleware.go @@ -0,0 +1,100 @@ +package handler + +import ( + "strings" + "time" + + "github.com/gin-gonic/gin" + + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// OpsLogStreamMiddleware fans every gateway request out to the in-memory +// OpsLogBroadcaster so admin tools can subscribe to a real-time SSE feed. +// +// This is intentionally separate from OpsErrorLoggerMiddleware: +// - OpsErrorLoggerMiddleware persists 4xx/5xx into the database for audit. +// - This middleware streams every request (success + failure) for live UX. +// +// The broadcaster.Publish call is non-blocking by design (see the +// implementation): a slow/missing subscriber NEVER stalls the request path. +// Empty broadcaster (nil receiver, or no subscribers) is a no-op. +func OpsLogStreamMiddleware(b *service.OpsLogBroadcaster) gin.HandlerFunc { + if b == nil { + return func(c *gin.Context) { c.Next() } + } + return func(c *gin.Context) { + start := time.Now() + c.Next() + + entry := service.OpsLogEntry{ + Time: start, + Method: c.Request.Method, + Path: c.Request.URL.Path, + Status: c.Writer.Status(), + LatencyMs: time.Since(start).Milliseconds(), + } + + if v, ok := c.Get(opsModelKey); ok { + if s, ok := v.(string); ok { + entry.Model = s + } + } + if v, ok := c.Get(opsStreamKey); ok { + if streamFlag, ok := v.(bool); ok { + entry.Stream = streamFlag + } + } + if v, ok := c.Get(opsAccountIDKey); ok { + switch t := v.(type) { + case int64: + entry.AccountID = t + case int: + entry.AccountID = int64(t) + } + } + + // Best-effort api-key + group + user from middleware context. + if apiKey, ok := servermiddleware.GetAPIKeyFromContext(c); ok && apiKey != nil { + entry.APIKeyID = apiKey.ID + if apiKey.GroupID != nil { + entry.GroupID = *apiKey.GroupID + } + entry.UserID = apiKey.UserID + } + + // Pull upstream error context (set by gateway services on retries). + if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok { + switch t := v.(type) { + case int: + entry.UpstreamCode = t + case int64: + entry.UpstreamCode = int(t) + } + } + if v, ok := c.Get(service.OpsUpstreamErrorMessageKey); ok { + if s, ok := v.(string); ok { + entry.ErrorMessage = trimForStream(s) + } + } + if v, ok := c.Get(service.OpsUpstreamErrorDetailKey); ok { + if s, ok := v.(string); ok { + entry.ErrorDetail = trimForStream(s) + } + } + + b.Publish(entry) + } +} + +// trimForStream caps long error strings so a single broken upstream cannot +// flood the SSE channel with megabyte-sized error blobs. +func trimForStream(s string) string { + const max = 512 + s = strings.TrimSpace(s) + if len(s) > max { + return s[:max] + "…" + } + return s +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index d440d115..ce161902 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -80,11 +80,11 @@ func ProvideSystemHandler(updateService *service.UpdateService, lockService *ser } // ProvideWindsurfHandler returns nil when windsurf auth service is disabled. -func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService *service.WindsurfLSService, probeService *service.WindsurfProbeService) *admin.WindsurfHandler { +func ProvideWindsurfHandler(authService *service.WindsurfAuthService, lsService *service.WindsurfLSService, probeService *service.WindsurfProbeService, tierAccessService *service.WindsurfTierAccessService) *admin.WindsurfHandler { if authService == nil { return nil } - return admin.NewWindsurfHandler(authService, lsService, probeService) + return admin.NewWindsurfHandler(authService, lsService, probeService, tierAccessService) } // ProvideSettingHandler creates SettingHandler with version from BuildInfo diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 9c13a3cc..fba358e7 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -486,7 +486,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" - reqBody.Metadata.IDEVersion = "1.20.6" + reqBody.Metadata.IDEVersion = currentUserAgentVersion() reqBody.Metadata.IDEName = "antigravity" bodyBytes, err := json.Marshal(reqBody) diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 5bb17183..ce9f0ec4 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -140,13 +140,19 @@ func GetClientCredentials(isEnterprise bool) (ClientCredentials, error) { } // BaseURLsForAccount 根据 isGcpTos 返回有序 URL 列表。 -// 企业账号(isGcpTos=true)优先走 prod;个人账号优先走 daily(与真实 IDE 一致)。 -// sandbox 作为最后兜底,仅在 prod/daily 都不可用时使用。 +// +// - 企业账号(isGcpTos=true):prod → daily → sandbox +// 企业账号拥有 GCP Workspace 权限,可访问真实 daily(daily-cloudcode-pa.googleapis.com)。 +// +// - 个人账号(isGcpTos=false):sandbox → prod +// 个人免费账号无权访问真实 daily,该端点对个人账号会直接返回 429 RESOURCE_EXHAUSTED。 +// sandbox(daily-cloudcode-pa.sandbox.googleapis.com)对个人账号可用,与上游行为一致。 func BaseURLsForAccount(isGcpTos bool) []string { if isGcpTos { return []string{antigravityProdBaseURL, antigravityDailyBaseURL, antigravitySandboxBaseURL} } - return []string{antigravityDailyBaseURL, antigravityProdBaseURL, antigravitySandboxBaseURL} + // 个人账号跳过真实 daily,直接使用 sandbox → prod 顺序(与上游 ForwardBaseURLs 一致) + return []string{antigravitySandboxBaseURL, antigravityProdBaseURL} } func getClientSecret() (string, error) { diff --git a/backend/internal/pkg/windsurf/auth_client.go b/backend/internal/pkg/windsurf/auth_client.go index 70cef08e..b0c356b6 100644 --- a/backend/internal/pkg/windsurf/auth_client.go +++ b/backend/internal/pkg/windsurf/auth_client.go @@ -276,6 +276,12 @@ func (a *AuthClient) loginViaFirebase(ctx context.Context, email, password strin }, nil } +// RegisterWithCodeiumDefault wraps RegisterWithCodeium with a freshly generated +// browser fingerprint, so callers (e.g. Google login) don't need to construct one. +func (a *AuthClient) RegisterWithCodeiumDefault(ctx context.Context, token, proxyURL string) (*RegisterResult, error) { + return a.RegisterWithCodeium(ctx, token, generateFingerprint(), proxyURL) +} + func (a *AuthClient) RegisterWithCodeium(ctx context.Context, token string, fp http.Header, proxyURL string) (*RegisterResult, error) { c := newClient(a.RequestTimeout, proxyURL) body := map[string]string{"firebase_id_token": token} diff --git a/backend/internal/pkg/windsurf/cold_threshold.go b/backend/internal/pkg/windsurf/cold_threshold.go new file mode 100644 index 00000000..9110f21b --- /dev/null +++ b/backend/internal/pkg/windsurf/cold_threshold.go @@ -0,0 +1,108 @@ +package windsurf + +import ( + "os" + "strconv" + "strings" + "sync" + "time" +) + +// ColdThresholdConfig parameterises the adaptive cold-stall timeout. Values +// can be overridden via env vars (see DefaultColdThresholdConfig). +type ColdThresholdConfig struct { + // Base is the timeout for an empty / tiny prompt (e.g. "ping"). + Base time.Duration + // PerKChar adds this much time for every 1000 characters of input. + PerKChar time.Duration + // Max caps the total threshold regardless of input size. Callers may + // further clamp via the runtime maxWait argument. + Max time.Duration +} + +var ( + defaultColdCfg ColdThresholdConfig + defaultColdCfgOnce sync.Once +) + +// DefaultColdThresholdConfig returns the active config. The first call +// resolves env-var overrides: +// +// WINDSURF_COLD_BASE_SECONDS (default 30) +// WINDSURF_COLD_PER_KCHAR_SEC (default 5) +// WINDSURF_COLD_MAX_SECONDS (default 90) +// +// Defaults match dwgx/WindsurfAPI's empirical "long inputs up to 90s" rule +// while preserving the prior 30s base for backward compatibility. +func DefaultColdThresholdConfig() ColdThresholdConfig { + defaultColdCfgOnce.Do(func() { + defaultColdCfg = ColdThresholdConfig{ + Base: envSeconds("WINDSURF_COLD_BASE_SECONDS", 30), + PerKChar: envSeconds("WINDSURF_COLD_PER_KCHAR_SEC", 5), + Max: envSeconds("WINDSURF_COLD_MAX_SECONDS", 90), + } + if defaultColdCfg.Base <= 0 { + defaultColdCfg.Base = 30 * time.Second + } + if defaultColdCfg.Max <= 0 { + defaultColdCfg.Max = 90 * time.Second + } + if defaultColdCfg.Max < defaultColdCfg.Base { + defaultColdCfg.Max = defaultColdCfg.Base + } + }) + return defaultColdCfg +} + +// AdaptiveColdThreshold returns the cold-stall timeout for a given prompt +// size, applying the active ColdThresholdConfig and an absolute upstream +// cap (typically the StreamCascadeChat maxWait constant). +// +// The returned threshold is the minimum of: +// +// Base + PerKChar * (inputChars / 1000) +// Config.Max +// upstreamCap (when > 0) +// +// inputChars < 0 is treated as 0. The result is always >= Base unless +// upstreamCap < Base, in which case upstreamCap wins. +func AdaptiveColdThreshold(inputChars int, upstreamCap time.Duration) time.Duration { + cfg := DefaultColdThresholdConfig() + return ComputeColdThreshold(cfg, inputChars, upstreamCap) +} + +// maxInputCharsForOverflowGuard caps inputChars before multiplication to keep +// the resulting time.Duration (int64 ns) from wrapping. 2^31 chars (~2GB) +// is already absurd for an LLM prompt; anything beyond is a bug or DoS attempt. +const maxInputCharsForOverflowGuard = 1<<31 - 1 + +// ComputeColdThreshold is the pure form used by tests and callers that want +// to inject a custom config without touching the singleton. +func ComputeColdThreshold(cfg ColdThresholdConfig, inputChars int, upstreamCap time.Duration) time.Duration { + if inputChars < 0 { + inputChars = 0 + } + if inputChars > maxInputCharsForOverflowGuard { + inputChars = maxInputCharsForOverflowGuard + } + scaled := cfg.Base + time.Duration(inputChars/1000)*cfg.PerKChar + if cfg.Max > 0 && scaled > cfg.Max { + scaled = cfg.Max + } + if upstreamCap > 0 && scaled > upstreamCap { + scaled = upstreamCap + } + return scaled +} + +func envSeconds(key string, defaultSec int) time.Duration { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + return time.Duration(defaultSec) * time.Second + } + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + return time.Duration(defaultSec) * time.Second + } + return time.Duration(v) * time.Second +} diff --git a/backend/internal/pkg/windsurf/cold_threshold_test.go b/backend/internal/pkg/windsurf/cold_threshold_test.go new file mode 100644 index 00000000..62e7a353 --- /dev/null +++ b/backend/internal/pkg/windsurf/cold_threshold_test.go @@ -0,0 +1,87 @@ +package windsurf + +import ( + "testing" + "time" +) + +func TestComputeColdThreshold(t *testing.T) { + cfg := ColdThresholdConfig{ + Base: 15 * time.Second, + PerKChar: 6 * time.Second, + Max: 90 * time.Second, + } + + tests := []struct { + name string + inputChars int + upstreamCap time.Duration + want time.Duration + }{ + {"empty prompt → base", 0, 0, 15 * time.Second}, + {"under 1k → still base", 999, 0, 15 * time.Second}, + {"1k → base + per-k", 1000, 0, 21 * time.Second}, + {"5k → base + 5*per-k", 5000, 0, 45 * time.Second}, + {"50k → base + 50*per-k clamped to max", 50000, 0, 90 * time.Second}, + {"100k → max", 100000, 0, 90 * time.Second}, + {"upstreamCap below max wins", 50000, 60 * time.Second, 60 * time.Second}, + {"upstreamCap above max no-op", 50000, 200 * time.Second, 90 * time.Second}, + {"negative input treated as 0", -100, 0, 15 * time.Second}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ComputeColdThreshold(cfg, tc.inputChars, tc.upstreamCap) + if got != tc.want { + t.Fatalf("got %v, want %v", got, tc.want) + } + }) + } +} + +func TestAdaptiveColdThreshold_DefaultsHonorEnv(t *testing.T) { + // Note: DefaultColdThresholdConfig caches via sync.Once, so this test + // exercises the public path but cannot assert overridden values without + // resetting the singleton. We verify the function returns a sensible + // value for representative inputs. + got := AdaptiveColdThreshold(5000, 200*time.Second) + if got <= 0 { + t.Fatalf("expected positive threshold, got %v", got) + } + if got > 200*time.Second { + t.Fatalf("expected <= upstreamCap, got %v", got) + } +} + +func TestComputeColdThreshold_PreservesLegacyDefaults(t *testing.T) { + // Regression: pre-PR-2 behaviour used base=30s, perKChar=5s/1500 chars, + // cap=180s. New defaults (base=30s, perKChar=5s/1000 chars, cap=90s) + // must NOT shorten timeouts for sub-12k inputs that already worked. + cfg := ColdThresholdConfig{Base: 30 * time.Second, PerKChar: 5 * time.Second, Max: 90 * time.Second} + for _, chars := range []int{0, 500, 1500, 6000, 12000} { + got := ComputeColdThreshold(cfg, chars, 180*time.Second) + legacy := 30*time.Second + time.Duration(chars/1500)*5*time.Second + if legacy > 180*time.Second { + legacy = 180 * time.Second + } + if got < legacy { + t.Fatalf("chars=%d: new=%v shorter than legacy=%v", chars, got, legacy) + } + } +} + +func TestComputeColdThreshold_OverflowGuard(t *testing.T) { + // Without the overflow guard, inputChars near math.MaxInt would make + // the duration multiplication wrap negative, returning a "stalled + // immediately" timeout. Verify the result is always >= Base and clamped + // to Max, never negative. + cfg := ColdThresholdConfig{Base: 15 * time.Second, PerKChar: 6 * time.Second, Max: 90 * time.Second} + for _, chars := range []int{1 << 31, 1 << 40, 1<<63 - 1} { + got := ComputeColdThreshold(cfg, chars, 200*time.Second) + if got <= 0 { + t.Fatalf("chars=%d: expected positive duration, got %v", chars, got) + } + if got != cfg.Max { + t.Fatalf("chars=%d: expected clamped to Max=%v, got %v", chars, cfg.Max, got) + } + } +} diff --git a/backend/internal/pkg/windsurf/local_ls.go b/backend/internal/pkg/windsurf/local_ls.go index 83a418ac..a0d0e132 100644 --- a/backend/internal/pkg/windsurf/local_ls.go +++ b/backend/internal/pkg/windsurf/local_ls.go @@ -502,10 +502,7 @@ func (l *LocalLSClient) StreamCascadeChat(ctx context.Context, token, modelUID, // Cold stall: active but no text/thinking after threshold elapsed := time.Since(startTime) - coldThreshold := 30*time.Second + time.Duration(inputChars/1500)*5*time.Second - if coldThreshold > maxWait { - coldThreshold = maxWait - } + coldThreshold := AdaptiveColdThreshold(inputChars, maxWait) if elapsed > coldThreshold && sawActive && !sawText && totalThinking == 0 { return nil, &CascadeModelError{Msg: fmt.Sprintf("Cascade planner stalled — no output after %ds", int(coldThreshold.Seconds()))} } diff --git a/backend/internal/pkg/windsurf/models.go b/backend/internal/pkg/windsurf/models.go index 5f797fea..bb482bfe 100644 --- a/backend/internal/pkg/windsurf/models.go +++ b/backend/internal/pkg/windsurf/models.go @@ -12,8 +12,23 @@ type ModelMeta struct { EnumValue int `json:"enum_value"` ModelUID string `json:"model_uid,omitempty"` Credit float64 `json:"credit"` + // EmulationFlavor controls how the chat service interprets tool-call + // output from this model. Values: + // "tool_use" — model emits well-formed tags (Claude family). + // "nlu" — model emits free-form text describing tool intent + // (GLM-4.7, Kimi-K2.5). Engage NLU fallback when no + // structured tags parse. + // "" / "auto" — try tool_use first, fall back to NLU heuristic only + // when the response carries clear NLU signals. + EmulationFlavor string `json:"emulation_flavor,omitempty"` } +const ( + EmulationFlavorAuto = "auto" + EmulationFlavorToolUse = "tool_use" + EmulationFlavorNLU = "nlu" +) + type ModelListEntry struct { ID string `json:"id"` Object string `json:"object"` @@ -263,8 +278,15 @@ func buildLookup() { } } +// ensureLookup builds the lookupMap once. It serializes via cloudModelsMu so +// concurrent first-touch readers don't race with MergeCloudModels (which +// rebuilds lookupMap as part of its write critical section). func ensureLookup() { - lookupOnce.Do(buildLookup) + lookupOnce.Do(func() { + cloudModelsMu.Lock() + defer cloudModelsMu.Unlock() + buildLookup() + }) } func ResolveModel(name string) string { @@ -272,6 +294,8 @@ func ResolveModel(name string) string { return "" } ensureLookup() + cloudModelsMu.RLock() + defer cloudModelsMu.RUnlock() if id, ok := lookupMap[name]; ok { return id } @@ -282,6 +306,8 @@ func ResolveModel(name string) string { } func GetModelInfo(id string) *ModelMeta { + cloudModelsMu.RLock() + defer cloudModelsMu.RUnlock() if m, ok := catalog[id]; ok { return &m } @@ -304,11 +330,37 @@ func GetChatMode(m *ModelMeta, legacyEnumCutoff int) string { return "cascade" } +// ResolveEmulationFlavor returns the emulation flavor for a model, applying +// per-provider defaults when the model entry doesn't override. +// +// Defaults follow dwgx/WindsurfAPI's empirical findings: +// - Anthropic Claude family: tool_use is reliable, no NLU needed. +// - Zhipu GLM, Moonshot Kimi: free-form intent text, NLU fallback required. +// - Everything else: auto (try tool_use first, fall back to NLU on signal). +func ResolveEmulationFlavor(m *ModelMeta) string { + if m == nil { + return EmulationFlavorAuto + } + if m.EmulationFlavor != "" { + return m.EmulationFlavor + } + switch m.Provider { + case "anthropic": + return EmulationFlavorToolUse + case "zhipu", "moonshot": + return EmulationFlavorNLU + default: + return EmulationFlavorAuto + } +} + var freeTierModels = []string{"gpt-4o-mini", "gemini-2.5-flash"} func GetTierModels(tier string) []string { switch tier { case "pro": + cloudModelsMu.RLock() + defer cloudModelsMu.RUnlock() keys := make([]string, 0, len(catalog)) for k := range catalog { keys = append(keys, k) @@ -325,6 +377,8 @@ func GetTierModels(tier string) []string { func ListModelsOpenAI() []ModelListEntry { ts := time.Now().Unix() + cloudModelsMu.RLock() + defer cloudModelsMu.RUnlock() entries := make([]ModelListEntry, 0, len(catalog)) for _, info := range catalog { entries = append(entries, ModelListEntry{ @@ -337,12 +391,19 @@ func ListModelsOpenAI() []ModelListEntry { return entries } -var cloudModelsMu sync.Mutex +// cloudModelsMu protects concurrent reads/writes of the package-level +// catalog/lookupMap state. MergeCloudModels takes the write lock; all +// public readers (GetModelInfo, ListModelsOpenAI, GetTierModels, ResolveModel) +// take the read lock. Without this, the new tier-access hot path would +// race against cloud-model merge on `go test -race`. +var cloudModelsMu sync.RWMutex func MergeCloudModels(configs []ModelInfo) int { cloudModelsMu.Lock() defer cloudModelsMu.Unlock() - ensureLookup() + // Already inside the write lock — call buildLookup directly to avoid + // re-entering cloudModelsMu via ensureLookup → would deadlock. + lookupOnce.Do(buildLookup) providerMap := map[string]string{ "MODEL_PROVIDER_ANTHROPIC": "anthropic", diff --git a/backend/internal/pkg/windsurf/nlu_extractor.go b/backend/internal/pkg/windsurf/nlu_extractor.go new file mode 100644 index 00000000..860c31d4 --- /dev/null +++ b/backend/internal/pkg/windsurf/nlu_extractor.go @@ -0,0 +1,202 @@ +package windsurf + +import ( + "encoding/json" + "fmt" + "regexp" + "sort" + "strings" +) + +// ExtractToolCallsNLU is a best-effort fallback parser used when a model +// (typically GLM-4.7 / Kimi family) emits tool-call intent in free-form +// text instead of well-formed tags. +// +// Strategy: +// 1. Look for "function:NAME" / "tool_call:NAME" / "call NAME" markers. +// 2. Look for the nearest JSON object after the marker as arguments. +// 3. Validate the function name is in the available tool list. +// +// availableTools is the list of tool names the request advertised. If empty, +// the extractor still tries name discovery but is best-effort. Returns nil +// when no plausible tool call is found — callers should treat that as +// "no tools" not "error". +func ExtractToolCallsNLU(text string, availableTools []string) []ToolCall { + if text == "" { + return nil + } + available := make(map[string]struct{}, len(availableTools)) + for _, name := range availableTools { + if n := strings.TrimSpace(name); n != "" { + available[n] = struct{}{} + } + } + + calls := nluFindMarkedCalls(text, available) + if len(calls) > 0 { + return calls + } + if len(available) > 0 { + // Last-resort: some models just say "I'll use edit_file with {...}" + // — try to spot any known tool name followed by a JSON object. + calls = nluFindBareNameCalls(text, available) + } + return calls +} + +// HasNLUSignal reports whether `text` looks like it intended to call a tool +// but malformed the tags. Used to decide whether to spend CPU on the NLU +// extractor when EmulationFlavor=auto. Conservative — false negatives are +// fine, false positives waste a few microseconds. +func HasNLUSignal(text string) bool { + if text == "" { + return false + } + lower := strings.ToLower(text) + for _, kw := range nluSignalKeywords { + if strings.Contains(lower, kw) { + return true + } + } + return false +} + +var nluSignalKeywords = []string{ + "tool_call", + "function_call", + "function:", + "tool:", + "arguments:", + "i'll call", + "i will call", + "calling tool", + "调用工具", + "使用工具", +} + +// nluMarkerRE matches "function: name", "tool_call: name", "call name" +// followed (possibly with delimiters) by a JSON object. The name capture +// stops at whitespace, comma, paren, or brace. +var nluMarkerRE = regexp.MustCompile(`(?i)(?:function|tool_call|tool|call)[\s:=]+([a-zA-Z_][a-zA-Z0-9_]*)`) + +func nluFindMarkedCalls(text string, available map[string]struct{}) []ToolCall { + matches := nluMarkerRE.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return nil + } + var calls []ToolCall + seen := make(map[string]struct{}) + for _, m := range matches { + name := text[m[2]:m[3]] + if _, ok := available[name]; len(available) > 0 && !ok { + continue + } + if _, dup := seen[name]; dup { + continue + } + args := nluFindNearestJSONAfter(text, m[1]) + if args == "" { + continue + } + seen[name] = struct{}{} + calls = append(calls, ToolCall{ + ID: nluCallID(name, len(calls)), + Name: name, + ArgumentsJSON: args, + }) + } + return calls +} + +func nluFindBareNameCalls(text string, available map[string]struct{}) []ToolCall { + // Iterate available names in deterministic (alphabetical) order so the + // returned slice is stable across runs and Go map randomization. Without + // this, two identical inputs can yield differently ordered tool-call + // slices, which makes upstream replay/retry behaviour inconsistent. + names := make([]string, 0, len(available)) + for name := range available { + names = append(names, name) + } + sort.Strings(names) + + var calls []ToolCall + seen := make(map[string]struct{}) + for _, name := range names { + idx := strings.Index(text, name) + if idx < 0 { + continue + } + args := nluFindNearestJSONAfter(text, idx+len(name)) + if args == "" { + continue + } + if _, dup := seen[name]; dup { + continue + } + seen[name] = struct{}{} + calls = append(calls, ToolCall{ + ID: nluCallID(name, len(calls)), + Name: name, + ArgumentsJSON: args, + }) + } + return calls +} + +// nluCallID generates a stable, namespaced ID for an NLU-extracted tool +// call. The numeric suffix prevents collisions when the same tool name +// appears in multiple turns within a session. +func nluCallID(name string, idx int) string { + return fmt.Sprintf("nlu_%s_%d", name, idx) +} + +// nluFindNearestJSONAfter scans forward from `start` and returns the first +// JSON object literal it encounters. Empty string when none found within a +// reasonable lookahead (4KB). +func nluFindNearestJSONAfter(text string, start int) string { + const lookahead = 4096 + end := start + lookahead + if end > len(text) { + end = len(text) + } + region := text[start:end] + open := strings.Index(region, "{") + if open < 0 { + return "" + } + depth := 0 + inString := false + escape := false + for i := open; i < len(region); i++ { + ch := region[i] + if escape { + escape = false + continue + } + if ch == '\\' { + escape = true + continue + } + if ch == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch ch { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + candidate := region[open : i+1] + if json.Valid([]byte(candidate)) { + return candidate + } + return "" + } + } + } + return "" +} diff --git a/backend/internal/pkg/windsurf/nlu_extractor_test.go b/backend/internal/pkg/windsurf/nlu_extractor_test.go new file mode 100644 index 00000000..f18bcda4 --- /dev/null +++ b/backend/internal/pkg/windsurf/nlu_extractor_test.go @@ -0,0 +1,144 @@ +package windsurf + +import ( + "testing" +) + +func TestExtractToolCallsNLU(t *testing.T) { + tools := []string{"edit_file", "read_file", "run_command"} + + tests := []struct { + name string + text string + available []string + wantCount int + wantName string + }{ + { + name: "marker form: function: edit_file with JSON", + text: `I'll use function: edit_file with {"path": "/tmp/x", "content": "abc"}`, + available: tools, + wantCount: 1, + wantName: "edit_file", + }, + { + name: "marker form: tool_call read_file", + text: `tool_call: read_file arguments: {"path": "/etc/hosts"}`, + available: tools, + wantCount: 1, + wantName: "read_file", + }, + { + name: "marker form: nested JSON object", + text: `function: run_command with {"cmd": "ls", "opts": {"long": true}}`, + available: tools, + wantCount: 1, + wantName: "run_command", + }, + { + name: "bare name fallback when no marker", + text: `Sure, I'll edit_file {"path": "/tmp/y"} for you.`, + available: tools, + wantCount: 1, + wantName: "edit_file", + }, + { + name: "unknown tool name rejected when available list is non-empty", + text: `function: delete_universe {"target": "all"}`, + available: tools, + wantCount: 0, + }, + { + name: "no JSON after marker yields no call", + text: `function: edit_file but I'm not sure what arguments to use`, + available: tools, + wantCount: 0, + }, + { + name: "empty text returns nil", + text: "", + available: tools, + wantCount: 0, + }, + { + name: "duplicate names deduplicated", + text: `function: edit_file {"path": "/a"} then function: edit_file {"path": "/b"}`, + available: tools, + wantCount: 1, + wantName: "edit_file", + }, + { + name: "name not in available list is rejected even when JSON valid", + text: `Calling foo with {"x": 1}`, + available: []string{"bar"}, + wantCount: 0, + }, + { + name: "marker with no available list still extracts", + text: `function: my_tool {"x": 1}`, + available: nil, + wantCount: 1, + wantName: "my_tool", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ExtractToolCallsNLU(tc.text, tc.available) + if len(got) != tc.wantCount { + t.Fatalf("expected %d call(s), got %d: %+v", tc.wantCount, len(got), got) + } + if tc.wantCount > 0 && got[0].Name != tc.wantName { + t.Fatalf("expected name %q, got %q", tc.wantName, got[0].Name) + } + if tc.wantCount > 0 && got[0].ArgumentsJSON == "" { + t.Fatalf("expected non-empty ArgumentsJSON, got %q", got[0].ArgumentsJSON) + } + }) + } +} + +func TestHasNLUSignal(t *testing.T) { + tests := []struct { + text string + want bool + }{ + {"function: edit_file {}", true}, + {"I'll call edit_file", true}, + {"calling tool edit_file", true}, + {"调用工具 edit_file", true}, + {"Hello, just a chat reply.", false}, + {"", false}, + {"foo", false}, + } + for _, tc := range tests { + t.Run(tc.text, func(t *testing.T) { + if got := HasNLUSignal(tc.text); got != tc.want { + t.Fatalf("HasNLUSignal(%q) = %v, want %v", tc.text, got, tc.want) + } + }) + } +} + +func TestResolveEmulationFlavor(t *testing.T) { + tests := []struct { + name string + meta *ModelMeta + want string + }{ + {"nil meta", nil, EmulationFlavorAuto}, + {"explicit override wins", &ModelMeta{Provider: "anthropic", EmulationFlavor: "nlu"}, "nlu"}, + {"anthropic default tool_use", &ModelMeta{Provider: "anthropic"}, EmulationFlavorToolUse}, + {"zhipu default nlu", &ModelMeta{Provider: "zhipu"}, EmulationFlavorNLU}, + {"moonshot default nlu", &ModelMeta{Provider: "moonshot"}, EmulationFlavorNLU}, + {"openai default auto", &ModelMeta{Provider: "openai"}, EmulationFlavorAuto}, + {"unknown provider auto", &ModelMeta{Provider: "xyz"}, EmulationFlavorAuto}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ResolveEmulationFlavor(tc.meta); got != tc.want { + t.Fatalf("ResolveEmulationFlavor = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index fad4d8e0..01f231e7 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -40,6 +40,7 @@ func ProvideRouter( settingService *service.SettingService, healthService *service.HealthService, redisClient *redis.Client, + opsLogBroadcaster *service.OpsLogBroadcaster, ) *gin.Engine { if cfg.Server.Mode == "release" { gin.SetMode(gin.ReleaseMode) @@ -96,7 +97,7 @@ func ProvideRouter( service.SetWebSearchManager(websearch.NewManager(configs, redisClient)) }) - return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient) + return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient, opsLogBroadcaster) } // ProvideHTTPServer 提供 HTTP 服务器 diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index ac05096c..1c9970cc 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -33,6 +33,7 @@ func SetupRouter( healthService *service.HealthService, cfg *config.Config, redisClient *redis.Client, + opsLogBroadcaster *service.OpsLogBroadcaster, ) *gin.Engine { // 缓存 iframe 页面的 origin 列表,用于动态注入 CSP frame-src var cachedFrameOrigins atomic.Pointer[[]string] @@ -82,7 +83,7 @@ func SetupRouter( } // 注册路由 - registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient) + registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, healthService, cfg, redisClient, opsLogBroadcaster) return r } @@ -101,6 +102,7 @@ func registerRoutes( healthService *service.HealthService, cfg *config.Config, redisClient *redis.Client, + opsLogBroadcaster *service.OpsLogBroadcaster, ) { // 通用路由(健康检查、状态等) routes.RegisterCommonRoutes(r, healthService) @@ -112,10 +114,10 @@ func registerRoutes( routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient, settingService) routes.RegisterUserRoutes(v1, h, jwtAuth, settingService) routes.RegisterAdminRoutes(v1, h, adminAuth) - routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) + routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster) // Windsurf gateway routes - routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) + routes.RegisterWindsurfGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, opsLogBroadcaster) routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 53734ce8..53047bdf 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -115,6 +115,8 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ops.GET("/user-concurrency", h.Admin.Ops.GetUserConcurrencyStats) ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability) ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary) + ops.GET("/logs/stream", h.Admin.Ops.LogStream) + ops.GET("/logs/recent", h.Admin.Ops.LogStreamRecent) // Alerts (rules + events) ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules) @@ -588,6 +590,7 @@ func registerWindsurfRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ws := admin.Group("/windsurf") { ws.POST("/accounts/login", h.Admin.Windsurf.Login) + ws.POST("/accounts/token-login", h.Admin.Windsurf.TokenLogin) ws.POST("/accounts/batch-login", h.Admin.Windsurf.BatchLogin) ws.POST("/accounts/:id/probe", h.Admin.Windsurf.Probe) ws.POST("/accounts/batch-probe", h.Admin.Windsurf.BatchProbe) @@ -596,6 +599,7 @@ func registerWindsurfRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ws.GET("/accounts/:id/runtime", h.Admin.Windsurf.GetRuntime) ws.GET("/ls/status", h.Admin.Windsurf.GetLSStatus) ws.GET("/models", h.Admin.Windsurf.ListModels) + ws.GET("/tier-access", h.Admin.Windsurf.GetTierAccess) } } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 9541cda1..b773697b 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -21,10 +21,12 @@ func RegisterGatewayRoutes( opsService *service.OpsService, settingService *service.SettingService, cfg *config.Config, + opsLogBroadcaster *service.OpsLogBroadcaster, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) + opsLogStream := handler.OpsLogStreamMiddleware(opsLogBroadcaster) endpointNorm := handler.InboundEndpointMiddleware() // 未分组 Key 拦截中间件(按协议格式区分错误响应) @@ -36,6 +38,7 @@ func RegisterGatewayRoutes( gateway.Use(bodyLimit) gateway.Use(clientRequestID) gateway.Use(opsErrorLogger) + gateway.Use(opsLogStream) gateway.Use(endpointNorm) gateway.Use(gin.HandlerFunc(apiKeyAuth)) gateway.Use(requireGroupAnthropic) diff --git a/backend/internal/server/routes/windsurf_gateway.go b/backend/internal/server/routes/windsurf_gateway.go index c0cd9bfb..21f3dd7e 100644 --- a/backend/internal/server/routes/windsurf_gateway.go +++ b/backend/internal/server/routes/windsurf_gateway.go @@ -18,6 +18,7 @@ func RegisterWindsurfGatewayRoutes( opsService *service.OpsService, settingService *service.SettingService, cfg *config.Config, + opsLogBroadcaster *service.OpsLogBroadcaster, ) { if h.Gateway == nil { return @@ -26,6 +27,7 @@ func RegisterWindsurfGatewayRoutes( bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) + opsLogStream := handler.OpsLogStreamMiddleware(opsLogBroadcaster) endpointNorm := handler.InboundEndpointMiddleware() requireGroupAnthropic := middleware.RequireGroupAssignment(settingService, middleware.AnthropicErrorWriter) @@ -33,6 +35,7 @@ func RegisterWindsurfGatewayRoutes( windsurfV1.Use(bodyLimit) windsurfV1.Use(clientRequestID) windsurfV1.Use(opsErrorLogger) + windsurfV1.Use(opsLogStream) windsurfV1.Use(endpointNorm) windsurfV1.Use(middleware.ForcePlatform(service.PlatformWindsurf)) windsurfV1.Use(gin.HandlerFunc(apiKeyAuth)) diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fa8b319f..4d45ab95 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -215,13 +215,20 @@ type antigravityRetryLoopResult struct { } // resolveAntigravityForwardBaseURL 解析转发用 base URL。 -// 根据账号类型选择优先 URL:企业账号(isGcpTos=true)→ prod;个人账号 → daily(与真实 IDE 一致)。 -// 可通过环境变量 GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL=daily 或 =prod 强制覆盖。 +// 根据账号类型选择优先 URL: +// - 企业账号(isGcpTos=true)→ prod 优先,可访问真实 daily +// - 个人账号(isGcpTos=false)→ sandbox 优先(真实 daily 对个人账号返回 429) +// +// 可通过环境变量 GATEWAY_ANTIGRAVITY_FORWARD_BASE_URL=sandbox/daily/prod 强制覆盖。 func resolveAntigravityForwardBaseURL(account *Account) string { mode := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityForwardBaseURLEnv))) if mode == "daily" { + // 注意:真实 daily(daily-cloudcode-pa.googleapis.com)仅对企业账号可用 return "https://daily-cloudcode-pa.googleapis.com" } + if mode == "sandbox" { + return "https://daily-cloudcode-pa.sandbox.googleapis.com" + } if mode == "prod" { return "https://cloudcode-pa.googleapis.com" } diff --git a/backend/internal/service/ops_log_broadcaster.go b/backend/internal/service/ops_log_broadcaster.go new file mode 100644 index 00000000..ae25f780 --- /dev/null +++ b/backend/internal/service/ops_log_broadcaster.go @@ -0,0 +1,237 @@ +// Package service exposes domain services. opslog provides a lightweight +// in-memory pub/sub for streaming admin log events without persisting them. +package service + +import ( + "sync" + "sync/atomic" + "time" +) + +// OpsLogEntry is one streamed log event. All fields are optional except Time +// — any missing data simply renders as empty in the admin UI. +type OpsLogEntry struct { + Time time.Time `json:"time"` + Method string `json:"method,omitempty"` + Path string `json:"path,omitempty"` + Status int `json:"status"` + LatencyMs int64 `json:"latency_ms"` + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + AccountID int64 `json:"account_id,omitempty"` + GroupID int64 `json:"group_id,omitempty"` + APIKeyID int64 `json:"api_key_id,omitempty"` + UserID int64 `json:"user_id,omitempty"` + Turns int `json:"turns,omitempty"` + PromptChars int `json:"prompt_chars,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + ErrorDetail string `json:"error_detail,omitempty"` + UpstreamCode int `json:"upstream_status,omitempty"` +} + +// OpsLogFilter restricts which entries a subscriber receives. Empty fields +// match everything; non-empty fields are AND-combined. +type OpsLogFilter struct { + MinStatus int + Model string + AccountID int64 + GroupID int64 + MinLatencyMs int64 +} + +// matches reports whether the entry passes the filter. +func (f OpsLogFilter) matches(e *OpsLogEntry) bool { + if f.MinStatus > 0 && e.Status < f.MinStatus { + return false + } + if f.Model != "" && e.Model != f.Model { + return false + } + if f.AccountID > 0 && e.AccountID != f.AccountID { + return false + } + if f.GroupID > 0 && e.GroupID != f.GroupID { + return false + } + if f.MinLatencyMs > 0 && e.LatencyMs < f.MinLatencyMs { + return false + } + return true +} + +// OpsLogBroadcaster is a lock-free-ish fan-out broadcaster with a bounded +// ring buffer for history (so freshly connected clients can prime their UI) +// and per-subscriber non-blocking sends (so a slow client never stalls a +// publish on the hot request path). +type OpsLogBroadcaster struct { + subsMu sync.RWMutex + subscribers map[int64]*opsLogSubscription + nextID atomic.Int64 + + historyMu sync.Mutex + history []OpsLogEntry + histHead int + histLen int + histCap int + + // publishedTotal / droppedTotal expose simple ops counters for an + // admin dashboard cell. Atomic so callers don't need to lock. + publishedTotal atomic.Int64 + droppedTotal atomic.Int64 +} + +type opsLogSubscription struct { + ch chan OpsLogEntry + filter OpsLogFilter + // closed is set atomically by unsubscribe() before any cleanup. Publish + // reads this flag under subsMu.RLock and skips closed subscriptions + // instead of attempting a send-on-closed-channel (which would panic). + closed atomic.Bool +} + +// NewOpsLogBroadcaster constructs a broadcaster. historyCap controls how +// many recent entries are kept for newly connected clients; 1000 is a sane +// default. Pass historyCap<=0 to disable the buffer entirely. +func NewOpsLogBroadcaster(historyCap int) *OpsLogBroadcaster { + if historyCap < 0 { + historyCap = 0 + } + b := &OpsLogBroadcaster{ + subscribers: make(map[int64]*opsLogSubscription), + histCap: historyCap, + } + if historyCap > 0 { + b.history = make([]OpsLogEntry, historyCap) + } + return b +} + +// Publish fans the entry out to every matching subscriber and appends it to +// the history buffer. Never blocks: if a subscriber's channel is full, the +// entry is dropped for that subscriber and the broadcaster's drop counter +// is incremented. Hot-path safe. +// +// The same entry value is delivered (by value) to all subscribers and to +// the ring buffer — no shared mutable pointer is leaked, so subscribers +// holding references to past entries cannot observe later mutations. +func (b *OpsLogBroadcaster) Publish(entry OpsLogEntry) { + if entry.Time.IsZero() { + entry.Time = time.Now() + } + b.publishedTotal.Add(1) + + b.appendHistory(entry) + + b.subsMu.RLock() + subs := make([]*opsLogSubscription, 0, len(b.subscribers)) + for _, s := range b.subscribers { + subs = append(subs, s) + } + b.subsMu.RUnlock() + + for _, s := range subs { + // Skip subscriptions that have been unsubscribed since we snapped + // the list. Without this check, a concurrent unsubscribe → close(ch) + // would race the send below and panic on send-to-closed-channel. + if s.closed.Load() { + continue + } + if !s.filter.matches(&entry) { + continue + } + select { + case s.ch <- entry: + default: + b.droppedTotal.Add(1) + } + } +} + +// Subscribe registers a new listener. The returned channel is buffered; +// callers MUST drain it. Cancel by calling the returned unsubscribe func +// (idempotent and safe from any goroutine). +// +// IMPORTANT: unsubscribe does NOT close the channel. Closing it would race +// with concurrent Publish goroutines that may already be holding a snapshot +// of the subscription pointer (causing a panic on send-to-closed-channel). +// Instead, unsubscribe (a) sets the closed atomic flag — Publish skips sends +// to flagged subs — and (b) removes from the subscriber map. Any in-flight +// send that slips past the flag check still proceeds harmlessly into the +// channel buffer and is garbage-collected with the channel once the caller +// drops its reference. Subscribers that need to know the broadcaster is +// done with them should rely on the parent ctx, not channel close. +func (b *OpsLogBroadcaster) Subscribe(filter OpsLogFilter, bufSize int) (<-chan OpsLogEntry, func()) { + if bufSize <= 0 { + bufSize = 1024 + } + id := b.nextID.Add(1) + sub := &opsLogSubscription{ + ch: make(chan OpsLogEntry, bufSize), + filter: filter, + } + b.subsMu.Lock() + b.subscribers[id] = sub + b.subsMu.Unlock() + + var unsubOnce sync.Once + unsubscribe := func() { + unsubOnce.Do(func() { + sub.closed.Store(true) + b.subsMu.Lock() + delete(b.subscribers, id) + b.subsMu.Unlock() + }) + } + return sub.ch, unsubscribe +} + +// Snapshot returns a copy of the recent history (oldest → newest), filtered +// by the given filter. Used by /admin/ops/logs/recent to prime newly opened +// dashboards before live events arrive. +func (b *OpsLogBroadcaster) Snapshot(filter OpsLogFilter, maxEntries int) []OpsLogEntry { + b.historyMu.Lock() + defer b.historyMu.Unlock() + + if b.histLen == 0 { + return nil + } + out := make([]OpsLogEntry, 0, b.histLen) + start := b.histHead - b.histLen + if start < 0 { + start += b.histCap + } + for i := 0; i < b.histLen; i++ { + idx := (start + i) % b.histCap + e := b.history[idx] + if !filter.matches(&e) { + continue + } + out = append(out, e) + } + if maxEntries > 0 && len(out) > maxEntries { + out = out[len(out)-maxEntries:] + } + return out +} + +// Stats reports cumulative publish/drop counts and the current subscriber +// count for diagnostics. +func (b *OpsLogBroadcaster) Stats() (published, dropped int64, subscribers int) { + b.subsMu.RLock() + subscribers = len(b.subscribers) + b.subsMu.RUnlock() + return b.publishedTotal.Load(), b.droppedTotal.Load(), subscribers +} + +func (b *OpsLogBroadcaster) appendHistory(e OpsLogEntry) { + if b.histCap == 0 { + return + } + b.historyMu.Lock() + defer b.historyMu.Unlock() + b.history[b.histHead] = e + b.histHead = (b.histHead + 1) % b.histCap + if b.histLen < b.histCap { + b.histLen++ + } +} diff --git a/backend/internal/service/ops_log_broadcaster_test.go b/backend/internal/service/ops_log_broadcaster_test.go new file mode 100644 index 00000000..c6ff261b --- /dev/null +++ b/backend/internal/service/ops_log_broadcaster_test.go @@ -0,0 +1,267 @@ +package service + +import ( + "sync" + "testing" + "time" +) + +func TestOpsLogBroadcaster_FanOutDeliversToMatchingSubscribers(t *testing.T) { + b := NewOpsLogBroadcaster(16) + + chHigh, unHigh := b.Subscribe(OpsLogFilter{MinStatus: 500}, 8) + defer unHigh() + chAll, unAll := b.Subscribe(OpsLogFilter{}, 8) + defer unAll() + + b.Publish(OpsLogEntry{Status: 200, Model: "claude-sonnet-4.6"}) + b.Publish(OpsLogEntry{Status: 503, Model: "kimi-k2.5"}) + + got200 := receiveOrTimeout(t, chAll, 200*time.Millisecond) + got503 := receiveOrTimeout(t, chAll, 200*time.Millisecond) + if got200.Status != 200 || got503.Status != 503 { + t.Fatalf("unexpected fan-out: %d / %d", got200.Status, got503.Status) + } + + gotHigh := receiveOrTimeout(t, chHigh, 200*time.Millisecond) + if gotHigh.Status != 503 { + t.Fatalf("filter MinStatus=500 should drop 200, got %d", gotHigh.Status) + } + expectNoMessage(t, chHigh, 50*time.Millisecond) +} + +func TestOpsLogBroadcaster_FilterByModelAndAccount(t *testing.T) { + b := NewOpsLogBroadcaster(0) + + chKimi, unKimi := b.Subscribe(OpsLogFilter{Model: "kimi-k2.5"}, 4) + defer unKimi() + chAcct, unAcct := b.Subscribe(OpsLogFilter{AccountID: 42}, 4) + defer unAcct() + + b.Publish(OpsLogEntry{Status: 200, Model: "claude-sonnet-4.6", AccountID: 1}) + b.Publish(OpsLogEntry{Status: 200, Model: "kimi-k2.5", AccountID: 42}) + + got := receiveOrTimeout(t, chKimi, 200*time.Millisecond) + if got.Model != "kimi-k2.5" { + t.Fatalf("expected kimi entry, got %+v", got) + } + expectNoMessage(t, chKimi, 50*time.Millisecond) + + gotA := receiveOrTimeout(t, chAcct, 200*time.Millisecond) + if gotA.AccountID != 42 { + t.Fatalf("expected account 42, got %d", gotA.AccountID) + } +} + +func TestOpsLogBroadcaster_NeverBlocksOnSlowSubscriber(t *testing.T) { + b := NewOpsLogBroadcaster(0) + + // Subscriber with buffer=1, never reads. After the second Publish, + // the entry must be dropped instead of blocking the publisher. + _, unsub := b.Subscribe(OpsLogFilter{}, 1) + defer unsub() + + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + b.Publish(OpsLogEntry{Status: 200}) + } + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("publisher blocked on slow subscriber") + } + + _, dropped, _ := b.Stats() + if dropped == 0 { + t.Fatal("expected dropped count > 0 when subscriber buffer overflows") + } +} + +func TestOpsLogBroadcaster_HistorySnapshot(t *testing.T) { + b := NewOpsLogBroadcaster(3) + + for i := 1; i <= 5; i++ { + b.Publish(OpsLogEntry{Status: 200 + i}) + } + + got := b.Snapshot(OpsLogFilter{}, 0) + if len(got) != 3 { + t.Fatalf("expected ring of 3, got %d", len(got)) + } + // Oldest → newest + if got[0].Status != 203 || got[1].Status != 204 || got[2].Status != 205 { + t.Fatalf("expected 203/204/205, got %d/%d/%d", got[0].Status, got[1].Status, got[2].Status) + } +} + +func TestOpsLogBroadcaster_HistoryAppliesFilter(t *testing.T) { + b := NewOpsLogBroadcaster(8) + + b.Publish(OpsLogEntry{Status: 200}) + b.Publish(OpsLogEntry{Status: 500}) + b.Publish(OpsLogEntry{Status: 503}) + + got := b.Snapshot(OpsLogFilter{MinStatus: 500}, 0) + if len(got) != 2 { + t.Fatalf("expected 2 high-status entries, got %d: %+v", len(got), got) + } +} + +func TestOpsLogBroadcaster_UnsubscribeIdempotent(t *testing.T) { + b := NewOpsLogBroadcaster(0) + _, unsub := b.Subscribe(OpsLogFilter{}, 1) + unsub() + unsub() // second call must not panic + unsub() // and a third +} + +func TestOpsLogBroadcaster_ZeroTimeFilledIn(t *testing.T) { + b := NewOpsLogBroadcaster(2) + + b.Publish(OpsLogEntry{Status: 200}) // Time intentionally zero + got := b.Snapshot(OpsLogFilter{}, 0) + if len(got) != 1 { + t.Fatalf("expected 1 entry, got %d", len(got)) + } + if got[0].Time.IsZero() { + t.Fatal("Publish should populate zero Time with time.Now()") + } +} + +func TestOpsLogBroadcaster_ConcurrentSafe(t *testing.T) { + b := NewOpsLogBroadcaster(64) + + // Spin a few subscribers and producers; rely on -race to surface + // any concurrency bugs in the fan-out path. + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + ch, unsub := b.Subscribe(OpsLogFilter{}, 32) + wg.Add(1) + go func() { + defer wg.Done() + defer unsub() + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + select { + case <-ch: + case <-time.After(5 * time.Millisecond): + } + } + }() + } + + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 200; j++ { + b.Publish(OpsLogEntry{Status: 200, Model: "x"}) + } + }() + } + + wg.Wait() + pub, _, _ := b.Stats() + if pub != 800 { + t.Fatalf("expected 800 publishes, got %d", pub) + } +} + +// TestOpsLogBroadcaster_ConcurrentUnsubscribeNoPanic exercises the exact +// race the audit identified: a Publish goroutine has snapped a subscription +// pointer while another goroutine unsubscribes (close(ch)) the moment before +// the send. Without the closed-flag guard in Publish, this races into +// "send on closed channel" and panics. With the guard, Publish observes +// closed=true and skips the send. Run with -race. +func TestOpsLogBroadcaster_ConcurrentUnsubscribeNoPanic(t *testing.T) { + b := NewOpsLogBroadcaster(0) + + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 200; j++ { + ch, unsub := b.Subscribe(OpsLogFilter{}, 1) + // Drain non-blockingly until the next publish lands. + done := make(chan struct{}) + go func() { + defer close(done) + timer := time.NewTimer(50 * time.Millisecond) + defer timer.Stop() + for { + select { + case <-ch: + case <-timer.C: + return + } + } + }() + b.Publish(OpsLogEntry{Status: 200}) + unsub() + b.Publish(OpsLogEntry{Status: 200}) // post-unsub publish must not panic + <-done + } + }() + } + wg.Wait() +} + +// TestOpsLogBroadcaster_SnapshotConcurrentWithPublish ensures Snapshot is +// safe under concurrent Publish (verifies subsMu vs historyMu coexistence). +func TestOpsLogBroadcaster_SnapshotConcurrentWithPublish(t *testing.T) { + b := NewOpsLogBroadcaster(32) + + stop := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + b.Publish(OpsLogEntry{Status: 200}) + } + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 200; i++ { + _ = b.Snapshot(OpsLogFilter{}, 0) + } + }() + + time.Sleep(50 * time.Millisecond) + close(stop) + wg.Wait() +} + +// helpers -------------------------------------------------------------- + +func receiveOrTimeout(t *testing.T, ch <-chan OpsLogEntry, d time.Duration) OpsLogEntry { + t.Helper() + select { + case e := <-ch: + return e + case <-time.After(d): + t.Fatalf("timeout waiting for entry after %s", d) + } + return OpsLogEntry{} +} + +func expectNoMessage(t *testing.T, ch <-chan OpsLogEntry, d time.Duration) { + t.Helper() + select { + case e := <-ch: + t.Fatalf("unexpected message: %+v", e) + case <-time.After(d): + } +} diff --git a/backend/internal/service/windsurf_chat_service.go b/backend/internal/service/windsurf_chat_service.go index 970110dd..0a20f39e 100644 --- a/backend/internal/service/windsurf_chat_service.go +++ b/backend/internal/service/windsurf_chat_service.go @@ -6,7 +6,9 @@ import ( "encoding/hex" "fmt" "log/slog" + "os" "strings" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -238,6 +240,22 @@ func (s *WindsurfChatService) chatCascade( } } + toolCalls := result.ToolCalls + if len(toolCalls) == 0 && shouldRunNLUFallback(meta, result.Text) { + if nluCalls := windsurf.ExtractToolCallsNLU(result.Text, availableToolNames(toolPreamble)); len(nluCalls) > 0 { + slog.Info("windsurf_cascade_nlu_fallback", + "model", modelKey, "calls", len(nluCalls), "text_chars", len(result.Text)) + toolCalls = make([]windsurf.NativeToolCall, 0, len(nluCalls)) + for _, c := range nluCalls { + toolCalls = append(toolCalls, windsurf.NativeToolCall{ + ID: c.ID, + Name: c.Name, + ArgumentsJSON: c.ArgumentsJSON, + }) + } + } + } + return &WindsurfChatResponse{ Text: result.Text, Thinking: result.Thinking, @@ -245,10 +263,119 @@ func (s *WindsurfChatService) chatCascade( Mode: "cascade", Usage: result.Usage, FirstTextAt: result.FirstTextAt, - ToolCalls: result.ToolCalls, + ToolCalls: toolCalls, }, nil } +// shouldRunNLUFallback decides whether to spend CPU on the NLU extractor. +// Two cases run the fallback: (a) model is explicitly NLU-flavored — +// always extract, even when no obvious signal, since these models routinely +// emit half-broken intent; (b) flavor is auto and the text shows NLU +// signals. tool_use flavored models (Claude family) are NEVER run through +// NLU because their tags are reliable and an erroneous fallback would +// invent calls the user did not request. Setting WINDSURF_NLU_FALLBACK_DISABLED=1 +// short-circuits all of the above. +func shouldRunNLUFallback(meta *windsurf.ModelMeta, text string) bool { + if text == "" { + return false + } + if isNLUFallbackDisabled() { + return false + } + flavor := windsurf.ResolveEmulationFlavor(meta) + switch flavor { + case windsurf.EmulationFlavorNLU: + return true + case windsurf.EmulationFlavorAuto: + return windsurf.HasNLUSignal(text) + default: + return false + } +} + +// availableToolNames extracts tool names from the cascade tool preamble for +// validation in the NLU extractor. Format heuristic: lines like +// "- TOOL_NAME(...)" or "name: TOOL_NAME". Returns nil when nothing parses +// — extractor still works in best-effort mode. +func availableToolNames(preamble string) []string { + if preamble == "" { + return nil + } + var names []string + seen := make(map[string]struct{}) + for _, raw := range strings.Split(preamble, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + if name := extractToolNameFromPreambleLine(line); name != "" { + if _, dup := seen[name]; !dup { + seen[name] = struct{}{} + names = append(names, name) + } + } + } + return names +} + +func extractToolNameFromPreambleLine(line string) string { + trim := strings.TrimLeft(line, "-* \t") + if trim == "" { + return "" + } + // "name: foo" form + if strings.HasPrefix(strings.ToLower(trim), "name:") { + return strings.TrimSpace(trim[len("name:"):]) + } + // "foo(args)" form — take identifier before "(". + if idx := strings.IndexByte(trim, '('); idx > 0 { + candidate := strings.TrimSpace(trim[:idx]) + if isIdentifier(candidate) { + return candidate + } + } + return "" +} + +func isIdentifier(s string) bool { + if s == "" { + return false + } + for i, r := range s { + switch { + case r == '_': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9' && i > 0: + default: + return false + } + } + return true +} + +// nluFallbackDisabledFn resolves the NLU-fallback kill switch. Tests may +// replace it via withNLUFallbackDisabledFn to assert behaviour. Production +// uses readNLUFallbackDisabledEnvOnce, which reads the env var exactly once. +var nluFallbackDisabledFn = readNLUFallbackDisabledEnvOnce + +func isNLUFallbackDisabled() bool { + return nluFallbackDisabledFn() +} + +var ( + nluFallbackDisabledCachedOnce sync.Once + nluFallbackDisabledCached bool +) + +func readNLUFallbackDisabledEnvOnce() bool { + nluFallbackDisabledCachedOnce.Do(func() { + v := strings.ToLower(strings.TrimSpace(os.Getenv("WINDSURF_NLU_FALLBACK_DISABLED"))) + nluFallbackDisabledCached = v == "1" || v == "true" || v == "yes" || v == "on" + }) + return nluFallbackDisabledCached +} + // buildCascadeCacheKey 构造 Cascade 复用 cache 的 key。 // 任一组件变化(账号、模型、LS 实例、会话、system prompt)都会自动 cache miss。 func buildCascadeCacheKey(groupID, accountID int64, modelUID, lsEndpoint, sessionHash, sysHash string) string { diff --git a/backend/internal/service/windsurf_google_login_test.go b/backend/internal/service/windsurf_google_login_test.go new file mode 100644 index 00000000..de0b6647 --- /dev/null +++ b/backend/internal/service/windsurf_google_login_test.go @@ -0,0 +1,237 @@ +package service + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" +) + +// tokenLoginRepoStub is a minimal AccountRepository stub used by +// TestWindsurfAuthService_TokenLogin_*. It implements just FindByCredentialField +// (the only repo method TokenLogin reaches before the validation short-circuits). +// All other methods panic so accidental calls are loud. +type tokenLoginRepoStub struct { + existing []Account + findErr error +} + +func (s *tokenLoginRepoStub) FindByCredentialField(_ context.Context, _, _, _ string) ([]Account, error) { + return s.existing, s.findErr +} + +func (*tokenLoginRepoStub) Create(context.Context, *Account) error { panic("unexpected") } +func (*tokenLoginRepoStub) GetByID(context.Context, int64) (*Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) GetByIDs(context.Context, []int64) ([]*Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ExistsByID(context.Context, int64) (bool, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) GetByCRSAccountID(context.Context, string) (*Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) FindByExtraField(context.Context, string, any) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) Update(context.Context, *Account) error { panic("unexpected") } +func (*tokenLoginRepoStub) Delete(context.Context, int64) error { panic("unexpected") } +func (*tokenLoginRepoStub) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListByGroup(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListActive(context.Context) ([]Account, error) { panic("unexpected") } +func (*tokenLoginRepoStub) ListByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) UpdateLastUsed(context.Context, int64) error { panic("unexpected") } +func (*tokenLoginRepoStub) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) SetError(context.Context, int64, string) error { panic("unexpected") } +func (*tokenLoginRepoStub) ClearError(context.Context, int64) error { panic("unexpected") } +func (*tokenLoginRepoStub) SetSchedulable(context.Context, int64, bool) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) BindGroups(context.Context, int64, []int64) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulable(context.Context) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableByGroupID(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableUngroupedByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) SetRateLimited(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) SetOverloaded(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) ClearTempUnschedulable(context.Context, int64) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) ClearRateLimit(context.Context, int64) error { panic("unexpected") } +func (*tokenLoginRepoStub) ClearAntigravityQuotaScopes(context.Context, int64) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) ClearModelRateLimits(context.Context, int64) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) UpdateExtra(context.Context, int64, map[string]any) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) BulkUpdate(context.Context, []int64, AccountBulkUpdate) (int64, error) { + panic("unexpected") +} +func (*tokenLoginRepoStub) IncrementQuotaUsed(context.Context, int64, float64) error { + panic("unexpected") +} +func (*tokenLoginRepoStub) ResetQuotaUsed(context.Context, int64) error { panic("unexpected") } + +// TestWindsurfAuthService_TokenLogin_Validation exercises input validation and +// dedup short-circuits in TokenLogin (these run before any external dependency +// is touched). +func TestWindsurfAuthService_TokenLogin_Validation(t *testing.T) { + tests := []struct { + name string + input *WindsurfTokenLoginInput + repo *tokenLoginRepoStub + wantErr string + }{ + { + name: "empty token rejected", + input: &WindsurfTokenLoginInput{Email: "user@example.com"}, + repo: &tokenLoginRepoStub{}, + wantErr: "token required", + }, + { + name: "duplicate email rejected with conflict error", + input: &WindsurfTokenLoginInput{Token: "fake-token", Email: "dup@example.com"}, + repo: &tokenLoginRepoStub{existing: []Account{{ID: 42}}}, + wantErr: "already exists", + }, + { + name: "find error propagated", + input: &WindsurfTokenLoginInput{Token: "fake-token", Email: "boom@example.com"}, + repo: &tokenLoginRepoStub{findErr: errors.New("db down")}, + wantErr: "check existing account", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + svc := &WindsurfAuthService{ + accountRepo: tc.repo, + authClient: &windsurf.AuthClient{}, + } + + _, err := svc.TokenLogin(context.Background(), tc.input) + if err == nil { + t.Fatalf("expected error containing %q, got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got %q", tc.wantErr, err.Error()) + } + }) + } +} + +// TestWindsurfAuthService_TokenLogin_PlatformConst guards against accidental +// drift in the platform/type constants used to persist the account. +func TestWindsurfAuthService_TokenLogin_PlatformConst(t *testing.T) { + if domain.PlatformWindsurf == "" { + t.Fatal("PlatformWindsurf constant is empty") + } + if domain.AccountTypeWindsurfSession == "" { + t.Fatal("AccountTypeWindsurfSession constant is empty") + } +} + +// TestWindsurfAuthService_TokenLogin_TypedErrors verifies that validation +// failures surface as ApplicationError with the right HTTP code, so the +// handler maps them to 4xx instead of 500. +func TestWindsurfAuthService_TokenLogin_TypedErrors(t *testing.T) { + cases := []struct { + name string + input *WindsurfTokenLoginInput + repo *tokenLoginRepoStub + wantCode int + }{ + { + name: "missing token returns 400", + input: &WindsurfTokenLoginInput{Email: "x@y.z"}, + repo: &tokenLoginRepoStub{}, + wantCode: 400, + }, + { + name: "duplicate email returns 409", + input: &WindsurfTokenLoginInput{Token: "tok", Email: "dup@example.com"}, + repo: &tokenLoginRepoStub{existing: []Account{{ID: 1}}}, + wantCode: 409, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc := &WindsurfAuthService{ + accountRepo: tc.repo, + authClient: &windsurf.AuthClient{}, + } + + _, err := svc.TokenLogin(context.Background(), tc.input) + if err == nil { + t.Fatalf("expected error, got nil") + } + if got := infraerrors.Code(err); got != tc.wantCode { + t.Fatalf("expected HTTP code %d, got %d (err=%v)", tc.wantCode, got, err) + } + }) + } +} diff --git a/backend/internal/service/windsurf_nlu_fallback_test.go b/backend/internal/service/windsurf_nlu_fallback_test.go new file mode 100644 index 00000000..59289e09 --- /dev/null +++ b/backend/internal/service/windsurf_nlu_fallback_test.go @@ -0,0 +1,123 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" +) + +func TestShouldRunNLUFallback(t *testing.T) { + // Force the fallback-enabled path; tests below cover the disabled path + // via withNLUFallbackDisabledFn directly. The production env-Once design + // makes t.Setenv ineffective once init has fired in a prior test. + prev := nluFallbackDisabledFn + nluFallbackDisabledFn = func() bool { return false } + t.Cleanup(func() { nluFallbackDisabledFn = prev }) + + tests := []struct { + name string + meta *windsurf.ModelMeta + text string + want bool + }{ + {"empty text never runs", &windsurf.ModelMeta{Provider: "zhipu"}, "", false}, + {"explicit nlu always runs", &windsurf.ModelMeta{Provider: "zhipu"}, "Sure, I helped.", true}, + {"explicit override nlu wins over provider", &windsurf.ModelMeta{Provider: "anthropic", EmulationFlavor: "nlu"}, "any text", true}, + {"explicit tool_use never runs", &windsurf.ModelMeta{Provider: "anthropic"}, "function: edit_file {}", false}, + {"auto with no signal skips", &windsurf.ModelMeta{Provider: "openai"}, "Just a chat reply.", false}, + {"auto with signal runs", &windsurf.ModelMeta{Provider: "openai"}, `function: edit_file {"x":1}`, true}, + {"nil meta auto, signal", nil, "function: edit_file {}", true}, + {"nil meta auto, no signal", nil, "Hello world", false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := shouldRunNLUFallback(tc.meta, tc.text); got != tc.want { + t.Fatalf("shouldRunNLUFallback = %v, want %v", got, tc.want) + } + }) + } +} + +func TestShouldRunNLUFallback_DisabledByOverride(t *testing.T) { + prev := nluFallbackDisabledFn + nluFallbackDisabledFn = func() bool { return true } + t.Cleanup(func() { nluFallbackDisabledFn = prev }) + + got := shouldRunNLUFallback(&windsurf.ModelMeta{Provider: "zhipu"}, "function: x {}") + if got { + t.Fatal("expected disabled by override, got enabled") + } +} + +func TestAvailableToolNames(t *testing.T) { + tests := []struct { + name string + preamble string + want []string + }{ + {"empty preamble", "", nil}, + { + name: "function-call format with parens", + preamble: `Tools: +- edit_file(path, content) +- read_file(path) +- run_command(cmd)`, + want: []string{"edit_file", "read_file", "run_command"}, + }, + { + name: "name: form", + preamble: `tools: + - name: foo + - name: bar`, + want: []string{"foo", "bar"}, + }, + { + name: "deduplicates", + preamble: `- edit_file(p) +- edit_file(p)`, + want: []string{"edit_file"}, + }, + { + name: "ignores non-identifier lines", + preamble: `Use the following tools:`, + want: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := availableToolNames(tc.preamble) + if len(got) != len(tc.want) { + t.Fatalf("expected %v, got %v", tc.want, got) + } + for i := range got { + if got[i] != tc.want[i] { + t.Fatalf("expected %v, got %v", tc.want, got) + } + } + }) + } +} + +func TestIsIdentifier(t *testing.T) { + cases := []struct { + s string + want bool + }{ + {"", false}, + {"foo", true}, + {"foo_bar", true}, + {"FooBar", true}, + {"foo123", true}, + {"123foo", false}, + {"foo bar", false}, + {"foo-bar", false}, + {"foo.bar", false}, + } + for _, tc := range cases { + t.Run(tc.s, func(t *testing.T) { + if got := isIdentifier(tc.s); got != tc.want { + t.Fatalf("isIdentifier(%q) = %v, want %v", tc.s, got, tc.want) + } + }) + } +} diff --git a/backend/internal/service/windsurf_services.go b/backend/internal/service/windsurf_services.go index 4f77e667..c1d70823 100644 --- a/backend/internal/service/windsurf_services.go +++ b/backend/internal/service/windsurf_services.go @@ -8,6 +8,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/domain" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" ) @@ -244,6 +245,124 @@ func (s *WindsurfAuthService) Login(ctx context.Context, input *WindsurfLoginInp }, nil } +type WindsurfTokenLoginInput struct { + Token string + Email string + Name string + Notes *string + ProxyID *int64 + GroupIDs []int64 + Concurrency int + Priority int + ProbeAfter bool + LSInstanceID string +} + +// TokenLogin registers a Windsurf account by exchanging a token obtained from +// https://windsurf.com/show-auth-token (after the user signed in on +// windsurf.com via Google / GitHub / email) with Codeium's register_user +// endpoint. Because the OAuth round-trip happens entirely on windsurf.com, +// no Firebase Referer-restricted requests originate from our own domain — +// this is the only flow that works for self-hosted deployments. +func (s *WindsurfAuthService) TokenLogin(ctx context.Context, input *WindsurfTokenLoginInput) (*WindsurfLoginOutput, error) { + if input.Token == "" { + return nil, infraerrors.BadRequest("WINDSURF_TOKEN_REQUIRED", "token required") + } + + if input.Email != "" { + existing, err := s.accountRepo.FindByCredentialField(ctx, domain.PlatformWindsurf, "email", input.Email) + if err != nil { + return nil, fmt.Errorf("check existing account: %w", err) + } + if len(existing) > 0 { + return nil, infraerrors.Conflict( + "WINDSURF_ACCOUNT_EMAIL_EXISTS", + "windsurf account with this email already exists", + ) + } + } + + proxyURL := "" + if input.ProxyID != nil { + proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID) + if err != nil { + return nil, fmt.Errorf("get proxy: %w", err) + } + proxyURL = proxy.URL() + } + + reg, err := s.authClient.RegisterWithCodeiumDefault(ctx, input.Token, proxyURL) + if err != nil { + return nil, fmt.Errorf("codeium register (token): %w", err) + } + + emailForRecord := input.Email + if emailForRecord == "" { + emailForRecord = reg.Name // best-effort label when caller didn't supply one + } + + creds := WindsurfCredentials{ + Email: emailForRecord, + APIKey: reg.APIKey, + AuthMethod: "token", + APIServerURL: reg.APIServerURL, + RegisteredAt: time.Now().Format(time.RFC3339), + } + credMap := StoreWindsurfCredentials(creds) + + extra := WindsurfExtra{ + Profile: WindsurfProfileSnapshot{TierSource: "login"}, + Refresh: WindsurfRefreshState{}, + } + if input.LSInstanceID != "" { + extra.LSBinding = WindsurfLSBinding{ContainerID: input.LSInstanceID} + } + extraMap := StoreWindsurfExtra(extra) + + name := input.Name + if name == "" { + name = reg.Name + } + if name == "" { + name = emailForRecord + } + if name == "" { + name = "Windsurf Account" + } + + concurrency := input.Concurrency + if concurrency <= 0 { + concurrency = 1 + } + + createInput := &CreateAccountInput{ + Name: name, + Notes: input.Notes, + Platform: domain.PlatformWindsurf, + Type: domain.AccountTypeWindsurfSession, + Credentials: credMap, + Extra: extraMap, + ProxyID: input.ProxyID, + Concurrency: concurrency, + Priority: input.Priority, + GroupIDs: input.GroupIDs, + } + + account, err := s.adminSvc.CreateAccount(ctx, createInput) + if err != nil { + return nil, fmt.Errorf("create account: %w", err) + } + + return &WindsurfLoginOutput{ + AccountID: account.ID, + Email: emailForRecord, + Tier: "unknown", + AuthMethod: "token", + APIKeyPresent: reg.APIKey != "", + RefreshTokenPresent: false, + }, nil +} + func (s *WindsurfAuthService) BatchLogin(ctx context.Context, items []string, proxyID *int64, groupIDs []int64, concurrency, priority int, probeAfter bool) ([]WindsurfBatchResult, error) { results := make([]WindsurfBatchResult, 0, len(items)) diff --git a/backend/internal/service/windsurf_tier_access_service.go b/backend/internal/service/windsurf_tier_access_service.go new file mode 100644 index 00000000..2e954225 --- /dev/null +++ b/backend/internal/service/windsurf_tier_access_service.go @@ -0,0 +1,216 @@ +package service + +import ( + "context" + "fmt" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/Wei-Shaw/sub2api/internal/pkg/windsurf" +) + +// WindsurfTierAccessRow describes account-pool availability for one model. +// +// Counts are exclusive: free + pro + trial sums to total schedulable accounts +// that can serve the model; blocked excludes accounts whose schedulable=false +// or capability check failed. Total = free + pro + trial. +type WindsurfTierAccessRow struct { + Model string `json:"model"` + Provider string `json:"provider"` + EmulationFlavor string `json:"emulation_flavor"` + Free int `json:"free"` + Pro int `json:"pro"` + Trial int `json:"trial"` + Blocked int `json:"blocked"` + Total int `json:"total"` +} + +// WindsurfTierAccessSnapshot is the cacheable result of a tier-access scan. +type WindsurfTierAccessSnapshot struct { + GeneratedAt time.Time `json:"generated_at"` + Accounts int `json:"accounts_considered"` + Rows []WindsurfTierAccessRow `json:"rows"` +} + +// WindsurfTierAccessService aggregates per-model availability from the +// account pool. The Snapshot result is cached for cacheTTL to keep this +// cheap when called from a busy admin dashboard. +type WindsurfTierAccessService struct { + accountRepo AccountRepository + cacheTTL time.Duration + + cache atomic.Pointer[WindsurfTierAccessSnapshot] + mu sync.Mutex // guards rebuild +} + +// NewWindsurfTierAccessService creates a service with a default 60s cache. +func NewWindsurfTierAccessService(accountRepo AccountRepository) *WindsurfTierAccessService { + return &WindsurfTierAccessService{ + accountRepo: accountRepo, + cacheTTL: 60 * time.Second, + } +} + +// Snapshot returns the latest tier-access snapshot, rebuilding from the +// repository when the cache is stale or absent. Concurrent callers during +// a rebuild get the freshly generated snapshot; only one rebuild runs at a +// time. +func (s *WindsurfTierAccessService) Snapshot(ctx context.Context) (*WindsurfTierAccessSnapshot, error) { + if cached := s.cache.Load(); cached != nil && time.Since(cached.GeneratedAt) < s.cacheTTL { + return cached, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + // Re-check after acquiring lock — another goroutine may have rebuilt. + if cached := s.cache.Load(); cached != nil && time.Since(cached.GeneratedAt) < s.cacheTTL { + return cached, nil + } + + snap, err := s.build(ctx) + if err != nil { + return nil, err + } + s.cache.Store(snap) + return snap, nil +} + +func (s *WindsurfTierAccessService) build(ctx context.Context) (*WindsurfTierAccessSnapshot, error) { + accounts, err := s.accountRepo.ListByPlatform(ctx, domain.PlatformWindsurf) + if err != nil { + return nil, fmt.Errorf("list windsurf accounts: %w", err) + } + + byModel := make(map[string]*tierCounter) + getCounter := func(model string) *tierCounter { + c, ok := byModel[model] + if !ok { + c = &tierCounter{} + byModel[model] = c + } + return c + } + + considered := 0 + for i := range accounts { + acct := &accounts[i] + creds := LoadWindsurfCredentials(acct.Credentials) + extra := LoadWindsurfExtra(acct.Extra) + if creds.APIKey == "" { + continue // un-registered account; cannot serve traffic + } + considered++ + + tierBucket := tierBucketFor(creds.Tier) + schedulable := acct.IsSchedulable() + + // 1) Walk the account's allowedModels (authoritative when present). + seen := make(map[string]struct{}) + for _, am := range extra.UserStatus.AllowedModels { + model := am.ModelKey + if model == "" { + model = am.Alias + } + if model == "" { + continue + } + seen[model] = struct{}{} + c := getCounter(model) + if !schedulable { + c.blocked++ + continue + } + capCheck := extra.Capabilities[model] + if !capabilityOK(capCheck) { + c.blocked++ + continue + } + incTier(c, tierBucket) + } + + // 2) Fall back to capability map for any model not already counted + // (older accounts may have probe data without allowedModels). + for model, capCheck := range extra.Capabilities { + if _, ok := seen[model]; ok { + continue + } + c := getCounter(model) + if !schedulable || !capabilityOK(capCheck) { + c.blocked++ + continue + } + incTier(c, tierBucket) + } + } + + rows := make([]WindsurfTierAccessRow, 0, len(byModel)) + for model, c := range byModel { + meta := windsurf.GetModelInfo(model) + row := WindsurfTierAccessRow{ + Model: model, + Free: c.free, + Pro: c.pro, + Trial: c.trial, + Blocked: c.blocked, + Total: c.free + c.pro + c.trial, + } + if meta != nil { + row.Provider = meta.Provider + row.EmulationFlavor = windsurf.ResolveEmulationFlavor(meta) + } + rows = append(rows, row) + } + sort.Slice(rows, func(i, j int) bool { + if rows[i].Total != rows[j].Total { + return rows[i].Total > rows[j].Total + } + return rows[i].Model < rows[j].Model + }) + + return &WindsurfTierAccessSnapshot{ + GeneratedAt: time.Now(), + Accounts: considered, + Rows: rows, + }, nil +} + +// tierCounter is the per-model tally used during a snapshot build. +type tierCounter struct { + free, pro, trial, blocked int +} + +func capabilityOK(c WindsurfModelCapability) bool { + if c.Reason == "not_entitled" { + return false + } + return c.Available +} + +func tierBucketFor(tier string) string { + switch tier { + case "pro": + return "pro" + case "trial": + return "trial" + case "free": + return "free" + default: + // Unknown tiers (legacy / pre-probe accounts) bucket as free for + // display purposes — they're typically free until probed. + return "free" + } +} + +func incTier(c *tierCounter, bucket string) { + switch bucket { + case "pro": + c.pro++ + case "trial": + c.trial++ + default: + c.free++ + } +} diff --git a/backend/internal/service/windsurf_tier_access_service_test.go b/backend/internal/service/windsurf_tier_access_service_test.go new file mode 100644 index 00000000..f1e6ee93 --- /dev/null +++ b/backend/internal/service/windsurf_tier_access_service_test.go @@ -0,0 +1,266 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/domain" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// tierAccessRepoStub satisfies AccountRepository with a hand-rolled +// ListByPlatform; every other method panics so accidental calls are loud. +type tierAccessRepoStub struct { + accounts []Account + err error +} + +func (s *tierAccessRepoStub) ListByPlatform(_ context.Context, _ string) ([]Account, error) { + return s.accounts, s.err +} + +func (*tierAccessRepoStub) Create(context.Context, *Account) error { panic("unexpected") } +func (*tierAccessRepoStub) GetByID(context.Context, int64) (*Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) GetByIDs(context.Context, []int64) ([]*Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ExistsByID(context.Context, int64) (bool, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) GetByCRSAccountID(context.Context, string) (*Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) FindByExtraField(context.Context, string, any) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) FindByCredentialField(context.Context, string, string, string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) Update(context.Context, *Account) error { panic("unexpected") } +func (*tierAccessRepoStub) Delete(context.Context, int64) error { panic("unexpected") } +func (*tierAccessRepoStub) List(context.Context, pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64, string) ([]Account, *pagination.PaginationResult, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListByGroup(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListActive(context.Context) ([]Account, error) { panic("unexpected") } +func (*tierAccessRepoStub) UpdateLastUsed(context.Context, int64) error { panic("unexpected") } +func (*tierAccessRepoStub) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + panic("unexpected") +} +func (*tierAccessRepoStub) SetError(context.Context, int64, string) error { panic("unexpected") } +func (*tierAccessRepoStub) ClearError(context.Context, int64) error { panic("unexpected") } +func (*tierAccessRepoStub) SetSchedulable(context.Context, int64, bool) error { + panic("unexpected") +} +func (*tierAccessRepoStub) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) BindGroups(context.Context, int64, []int64) error { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulable(context.Context) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableByGroupID(context.Context, int64) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatform(context.Context, string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) ListSchedulableUngroupedByPlatforms(context.Context, []string) ([]Account, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) SetRateLimited(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (*tierAccessRepoStub) SetModelRateLimit(context.Context, int64, string, time.Time) error { + panic("unexpected") +} +func (*tierAccessRepoStub) SetOverloaded(context.Context, int64, time.Time) error { + panic("unexpected") +} +func (*tierAccessRepoStub) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + panic("unexpected") +} +func (*tierAccessRepoStub) ClearTempUnschedulable(context.Context, int64) error { + panic("unexpected") +} +func (*tierAccessRepoStub) ClearRateLimit(context.Context, int64) error { panic("unexpected") } +func (*tierAccessRepoStub) ClearAntigravityQuotaScopes(context.Context, int64) error { + panic("unexpected") +} +func (*tierAccessRepoStub) ClearModelRateLimits(context.Context, int64) error { + panic("unexpected") +} +func (*tierAccessRepoStub) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { + panic("unexpected") +} +func (*tierAccessRepoStub) UpdateExtra(context.Context, int64, map[string]any) error { + panic("unexpected") +} +func (*tierAccessRepoStub) BulkUpdate(context.Context, []int64, AccountBulkUpdate) (int64, error) { + panic("unexpected") +} +func (*tierAccessRepoStub) IncrementQuotaUsed(context.Context, int64, float64) error { + panic("unexpected") +} +func (*tierAccessRepoStub) ResetQuotaUsed(context.Context, int64) error { panic("unexpected") } + +func mkAccount(id int64, tier string, status string, allowed []WindsurfAllowedModel, caps map[string]WindsurfModelCapability) Account { + creds := WindsurfCredentials{ + APIKey: "key-" + tier, + Tier: tier, + } + extra := WindsurfExtra{ + UserStatus: WindsurfUserStatusSnapshot{AllowedModels: allowed}, + Capabilities: caps, + } + return Account{ + ID: id, + Platform: domain.PlatformWindsurf, + Status: status, + Schedulable: status == StatusActive, + Credentials: StoreWindsurfCredentials(creds), + Extra: StoreWindsurfExtra(extra), + } +} + +func TestWindsurfTierAccessService_Snapshot_HappyPath(t *testing.T) { + repo := &tierAccessRepoStub{ + accounts: []Account{ + mkAccount(1, "free", StatusActive, + []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "kimi-k2"}}, + nil), + mkAccount(2, "pro", StatusActive, + []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}, {ModelKey: "claude-sonnet-4.6"}}, + nil), + mkAccount(3, "trial", StatusActive, + []WindsurfAllowedModel{{ModelKey: "claude-sonnet-4.6"}}, + nil), + }, + } + svc := NewWindsurfTierAccessService(repo) + snap, err := svc.Snapshot(context.Background()) + if err != nil { + t.Fatalf("snapshot: %v", err) + } + if snap.Accounts != 3 { + t.Fatalf("expected 3 accounts considered, got %d", snap.Accounts) + } + rowsByModel := make(map[string]WindsurfTierAccessRow) + for _, r := range snap.Rows { + rowsByModel[r.Model] = r + } + if got := rowsByModel["gemini-2.5-flash"]; got.Free != 1 || got.Pro != 1 || got.Trial != 0 || got.Total != 2 { + t.Fatalf("gemini-2.5-flash unexpected counts: %+v", got) + } + if got := rowsByModel["claude-sonnet-4.6"]; got.Free != 0 || got.Pro != 1 || got.Trial != 1 || got.Total != 2 { + t.Fatalf("claude-sonnet-4.6 unexpected counts: %+v", got) + } + if got := rowsByModel["kimi-k2"]; got.Free != 1 { + t.Fatalf("kimi-k2 unexpected counts: %+v", got) + } +} + +func TestWindsurfTierAccessService_Snapshot_BlockedAccountsCounted(t *testing.T) { + caps := map[string]WindsurfModelCapability{ + "gemini-2.5-flash": {Available: false, Reason: "not_entitled"}, + } + repo := &tierAccessRepoStub{ + accounts: []Account{ + mkAccount(1, "free", StatusActive, nil, caps), + mkAccount(2, "free", "paused", []WindsurfAllowedModel{{ModelKey: "gemini-2.5-flash"}}, nil), + }, + } + svc := NewWindsurfTierAccessService(repo) + snap, _ := svc.Snapshot(context.Background()) + row := findTierRow(snap, "gemini-2.5-flash") + if row == nil { + t.Fatal("expected gemini-2.5-flash row") + } + if row.Blocked != 2 || row.Total != 0 { + t.Fatalf("expected blocked=2 total=0, got %+v", row) + } +} + +func TestWindsurfTierAccessService_Snapshot_SkipsUnregisteredAccounts(t *testing.T) { + acct := Account{ + ID: 1, + Platform: domain.PlatformWindsurf, + Status: StatusActive, + Schedulable: true, + Credentials: StoreWindsurfCredentials(WindsurfCredentials{Email: "a@b.c"}), // no APIKey + } + svc := NewWindsurfTierAccessService(&tierAccessRepoStub{accounts: []Account{acct}}) + snap, _ := svc.Snapshot(context.Background()) + if snap.Accounts != 0 { + t.Fatalf("expected accounts considered=0, got %d", snap.Accounts) + } + if len(snap.Rows) != 0 { + t.Fatalf("expected no rows, got %+v", snap.Rows) + } +} + +func TestWindsurfTierAccessService_Snapshot_PropagatesRepoError(t *testing.T) { + svc := NewWindsurfTierAccessService(&tierAccessRepoStub{err: errors.New("db down")}) + if _, err := svc.Snapshot(context.Background()); err == nil { + t.Fatal("expected error") + } +} + +func TestWindsurfTierAccessService_Snapshot_CachesWithinTTL(t *testing.T) { + repo := &tierAccessRepoStub{ + accounts: []Account{ + mkAccount(1, "free", StatusActive, []WindsurfAllowedModel{{ModelKey: "x"}}, nil), + }, + } + svc := NewWindsurfTierAccessService(repo) + first, _ := svc.Snapshot(context.Background()) + + // Pointer equality is the cache-hit signal: atomic.Pointer.Store fires + // only on rebuild, so a returned pointer identical to the prior call + // proves the build() path was skipped. We mutate the underlying repo + // to a state that would yield a different snapshot if the rebuild + // actually ran — the assertion below catches a regression in either + // direction (TTL gate broken, or sync.Once style misuse). The 60s + // default TTL is large enough that this test never sees an expiry + // during normal CI runs. + repo.accounts = nil + second, _ := svc.Snapshot(context.Background()) + if first != second { + t.Fatal("expected cached pointer reuse") + } +} + +func findTierRow(snap *WindsurfTierAccessSnapshot, model string) *WindsurfTierAccessRow { + for i := range snap.Rows { + if snap.Rows[i].Model == model { + return &snap.Rows[i] + } + } + return nil +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index dfafa94e..699307e4 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -528,6 +528,8 @@ var ProviderSet = wire.NewSet( ProvideWindsurfTokenProvider, ProvideWindsurfRefreshService, ProvideWindsurfProbeService, + ProvideWindsurfTierAccessService, + ProvideOpsLogBroadcaster, ProvideChannelMonitorService, ProvideChannelMonitorRunner, NewChannelMonitorRequestTemplateService, @@ -546,6 +548,26 @@ func ProvideWindsurfAuthService(cfg *config.Config, accountRepo AccountRepositor return NewWindsurfAuthService(cfg.Windsurf, accountRepo, proxyRepo, adminSvc) } +// ProvideWindsurfTierAccessService creates the tier-access aggregator +// (nil when windsurf is disabled). +func ProvideWindsurfTierAccessService(cfg *config.Config, accountRepo AccountRepository) *WindsurfTierAccessService { + if !cfg.Windsurf.Enabled { + return nil + } + return NewWindsurfTierAccessService(accountRepo) +} + +// ProvideOpsLogBroadcaster builds the in-memory ops log fan-out. +// +// Always returns a non-nil broadcaster — middleware/handler call sites +// rely on a stable receiver and gracefully no-op when the feature is +// effectively disabled (e.g. monitoring globally turned off via OpsService). +// History capacity is fixed at 1000 entries to bound memory. +func ProvideOpsLogBroadcaster() *OpsLogBroadcaster { + const historyCap = 1000 + return NewOpsLogBroadcaster(historyCap) +} + // ProvideWindsurfLSService creates WindsurfLSService (nil when windsurf is disabled). func ProvideWindsurfLSService(cfg *config.Config) *WindsurfLSService { if !cfg.Windsurf.Enabled { diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index ac58eff4..0890fc4d 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -1418,7 +1418,190 @@ export const opsAPI = { updateMetricThresholds, listSystemLogs, cleanupSystemLogs, - getSystemLogSinkHealth + getSystemLogSinkHealth, + getRecentOpsLogs, + subscribeOpsLogStream } export default opsAPI + +// ===== Real-time ops log stream ===================================== + +export interface OpsLogEntry { + time: string + method?: string + path?: string + status: number + latency_ms: number + model?: string + stream?: boolean + account_id?: number + group_id?: number + api_key_id?: number + user_id?: number + turns?: number + prompt_chars?: number + error_message?: string + error_detail?: string + upstream_status?: number +} + +export interface OpsLogFilter { + min_status?: number + model?: string + account_id?: number + group_id?: number + min_latency_ms?: number +} + +export interface OpsLogRecentResponse { + entries: OpsLogEntry[] + published_total: number + dropped_total: number + subscribers: number +} + +function buildLogQuery(filter: OpsLogFilter): string { + const params = new URLSearchParams() + if (filter.min_status && filter.min_status > 0) params.set('min_status', String(filter.min_status)) + if (filter.model) params.set('model', filter.model) + if (filter.account_id && filter.account_id > 0) params.set('account_id', String(filter.account_id)) + if (filter.group_id && filter.group_id > 0) params.set('group_id', String(filter.group_id)) + if (filter.min_latency_ms && filter.min_latency_ms > 0) params.set('min_latency_ms', String(filter.min_latency_ms)) + const qs = params.toString() + return qs ? `?${qs}` : '' +} + +export async function getRecentOpsLogs( + filter: OpsLogFilter = {}, + max?: number +): Promise { + const params: Record = {} + if (filter.min_status) params.min_status = filter.min_status + if (filter.model) params.model = filter.model + if (filter.account_id) params.account_id = filter.account_id + if (filter.group_id) params.group_id = filter.group_id + if (filter.min_latency_ms) params.min_latency_ms = filter.min_latency_ms + if (max) params.max = max + + const { data } = await apiClient.get('/admin/ops/logs/recent', { params }) + return data +} + +export interface SubscribeLogsOptions { + onEntry: (entry: OpsLogEntry) => void + onStatus?: (status: 'connecting' | 'live' | 'closed' | 'error') => void + onError?: (err: Error) => void +} + +export type LogStreamHandle = { + close: () => void +} + +/** + * Subscribe to /admin/ops/logs/stream via fetch + ReadableStream. + * + * EventSource doesn't allow custom headers, so we cannot use it for our + * Bearer-token authenticated SSE endpoint. fetch with `accept: text/event-stream` + * gets us the same wire protocol with Authorization support. + * + * The returned handle is idempotent — calling close() multiple times is safe. + */ +export function subscribeOpsLogStream( + filter: OpsLogFilter, + opts: SubscribeLogsOptions +): LogStreamHandle { + const ctrl = new AbortController() + let closed = false + + const close = () => { + if (closed) return + closed = true + ctrl.abort() + opts.onStatus?.('closed') + } + + const run = async () => { + opts.onStatus?.('connecting') + const baseURL = (apiClient.defaults.baseURL ?? '/api/v1').replace(/\/+$/, '') + const url = `${baseURL}/admin/ops/logs/stream${buildLogQuery(filter)}` + const token = localStorage.getItem('auth_token') ?? '' + + let resp: Response + try { + resp = await fetch(url, { + method: 'GET', + signal: ctrl.signal, + credentials: 'include', + headers: { + accept: 'text/event-stream', + ...(token ? { Authorization: `Bearer ${token}` } : {}) + } + }) + } catch (e: any) { + if (closed || ctrl.signal.aborted) return + opts.onStatus?.('error') + opts.onError?.(e instanceof Error ? e : new Error(String(e))) + return + } + + if (!resp.ok || !resp.body) { + opts.onStatus?.('error') + opts.onError?.(new Error(`SSE ${resp.status} ${resp.statusText}`)) + return + } + + opts.onStatus?.('live') + const reader = resp.body.getReader() + const decoder = new TextDecoder('utf-8') + let buffer = '' + + try { + while (!closed) { + const { done, value } = await reader.read() + if (done) break + // Normalize CRLF and bare CR to LF before scanning. The SSE wire + // format permits \r\n line endings; without this, a trailing \r + // would leak into the JSON payload and silently break JSON.parse. + buffer += decoder.decode(value, { stream: true }).replace(/\r\n?/g, '\n') + + // SSE events are separated by a blank line. + let sep: number + while ((sep = buffer.indexOf('\n\n')) !== -1) { + const rawEvent = buffer.slice(0, sep) + buffer = buffer.slice(sep + 2) + + let dataLine = '' + for (const line of rawEvent.split('\n')) { + if (line.startsWith(':')) continue // comment / heartbeat + if (line.startsWith('data:')) { + // Per SSE spec, multiple `data:` lines in one event are + // joined by '\n', not concatenated. Our JSON-encoded entries + // never contain unescaped LF, but we follow the spec to + // future-proof against a field that emits raw bytes. + const piece = line.slice(5).replace(/^ /, '') + dataLine += dataLine ? '\n' + piece : piece + } + } + if (!dataLine) continue + try { + const parsed = JSON.parse(dataLine) as OpsLogEntry + opts.onEntry(parsed) + } catch { + // skip malformed payload — server uses well-formed JSON, this is defensive + } + } + } + } catch (e: any) { + if (!closed) { + opts.onStatus?.('error') + opts.onError?.(e instanceof Error ? e : new Error(String(e))) + } + } finally { + if (!closed) opts.onStatus?.('closed') + } + } + + void run() + return { close } +} diff --git a/frontend/src/api/admin/windsurf.ts b/frontend/src/api/admin/windsurf.ts index a87be534..e1ef35d7 100644 --- a/frontend/src/api/admin/windsurf.ts +++ b/frontend/src/api/admin/windsurf.ts @@ -4,9 +4,11 @@ import type { WindsurfLoginResponse, WindsurfBatchLoginRequest, WindsurfBatchLoginResponse, + WindsurfTokenLoginRequest, WindsurfRefreshTokenResponse, WindsurfLSStatusResponse, - WindsurfRuntimeResponse + WindsurfRuntimeResponse, + WindsurfTierAccessSnapshot } from '@/types' export async function login(req: WindsurfLoginRequest): Promise { @@ -14,6 +16,14 @@ export async function login(req: WindsurfLoginRequest): Promise { + const { data } = await apiClient.post( + '/admin/windsurf/accounts/token-login', + req + ) + return data +} + export async function batchLogin(req: WindsurfBatchLoginRequest): Promise { const { data } = await apiClient.post( '/admin/windsurf/accounts/batch-login', @@ -62,14 +72,21 @@ export async function getRuntime(accountId: number): Promise { + const { data } = await apiClient.get('/admin/windsurf/tier-access') + return data +} + export const windsurfAPI = { login, + tokenLogin, batchLogin, refreshToken, batchRefreshTokens, getLSStatus, listModels, - getRuntime + getRuntime, + getTierAccess } export default windsurfAPI diff --git a/frontend/src/components/account/WindsurfLoginModal.vue b/frontend/src/components/account/WindsurfLoginModal.vue index 6a9bf9a1..2b0c4f7b 100644 --- a/frontend/src/components/account/WindsurfLoginModal.vue +++ b/frontend/src/components/account/WindsurfLoginModal.vue @@ -9,8 +9,20 @@ {{ t('admin.windsurf.loginDesc') }}

- +
+