feat(viewer): add real-time request stream WebSocket endpoint

Adds GET /api/v1/admin/ops/ws/requests — a fan-out WebSocket that pushes
per-request metadata (method, path, model, account_id, status, latency_ms)
to all connected admin clients the moment each gateway dispatch completes.

- service/request_event_bus.go: lock-free pub/sub with non-blocking drop
  when per-subscriber buffer (64 slots) is full; nil-safe Publish
- service/request_event_bus_test.go: 6 tests (basic, fanout, drop, nil, close)
- GatewayHandler: records reqStartTime at entry; defer emits RequestEvent on
  every return; sets status success/error/rate_limited in both Gemini and
  Anthropic dispatch paths
- OpsHandler: accepts *RequestEventBus; wires it to RequestStreamWSHandler
- ops_ws_requests_handler.go: subscribes to bus, pushes JSON per event,
  reuses existing upgrader/conn-limit/ping-pong infrastructure
- Route: ws.GET("/requests", ...) alongside existing /ws/qps
- wire_gen.go: requestEventBus shared between OpsHandler and GatewayHandler
This commit is contained in:
win 2026-04-29 01:48:15 +08:00
parent d535688bfd
commit d1e2d39c26
10 changed files with 435 additions and 23 deletions

View File

@ -193,7 +193,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService)
soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage)
opsHandler := admin.NewOpsHandler(opsService)
requestEventBus := service.NewRequestEventBus()
opsHandler := admin.NewOpsHandler(opsService, requestEventBus)
updateCache := repository.NewUpdateCache(redisClient)
gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig)
serviceBuildInfo := provideServiceBuildInfo(buildInfo)
@ -223,7 +224,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService, requestEventBus)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)

View File

@ -16,7 +16,8 @@ import (
)
type OpsHandler struct {
opsService *service.OpsService
opsService *service.OpsService
requestEventBus *service.RequestEventBus
}
// GetErrorLogByID returns ops error log detail.
@ -70,8 +71,8 @@ func parseOpsViewParam(c *gin.Context) string {
}
}
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
func NewOpsHandler(opsService *service.OpsService, requestEventBus *service.RequestEventBus) *OpsHandler {
return &OpsHandler{opsService: opsService, requestEventBus: requestEventBus}
}
// GetErrorLogs lists ops error logs.

View File

@ -116,7 +116,7 @@ func newRuntimeOpsService(t *testing.T) *service.OpsService {
}
func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, false)
w := httptest.NewRecorder()
@ -128,7 +128,7 @@ func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) {
}
func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, false)
body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}`
@ -142,7 +142,7 @@ func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) {
}
func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) {
h := NewOpsHandler(newRuntimeOpsService(t))
h := NewOpsHandler(newRuntimeOpsService(t), nil)
r := newOpsRuntimeRouter(h, true)
payload := map[string]any{

View File

@ -35,7 +35,7 @@ func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine {
}
func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
h := NewOpsHandler(nil)
h := NewOpsHandler(nil, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -48,7 +48,7 @@ func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) {
func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -61,7 +61,7 @@ func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) {
func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -76,7 +76,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -89,7 +89,7 @@ func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) {
func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -110,7 +110,7 @@ func TestOpsSystemLogHandler_ListSuccess(t *testing.T) {
func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -124,7 +124,7 @@ func TestOpsSystemLogHandler_CleanupUnauthorized(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
@ -138,7 +138,7 @@ func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
@ -152,7 +152,7 @@ func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) {
func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
@ -166,7 +166,7 @@ func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) {
func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) {
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
@ -182,7 +182,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, true)
w := httptest.NewRecorder()
@ -197,7 +197,7 @@ func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) {
func TestOpsSystemLogHandler_Health(t *testing.T) {
sink := service.NewOpsSystemLogSink(nil)
svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink)
h := NewOpsHandler(svc)
h := NewOpsHandler(svc, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -209,7 +209,7 @@ func TestOpsSystemLogHandler_Health(t *testing.T) {
}
func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) {
h := NewOpsHandler(nil)
h := NewOpsHandler(nil, nil)
r := newOpsSystemLogTestRouter(h, false)
w := httptest.NewRecorder()
@ -222,7 +222,7 @@ func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T
svc := service.NewOpsService(nil, nil, &config.Config{
Ops: config.OpsConfig{Enabled: false},
}, nil, nil, nil, nil, nil, nil, nil, nil)
h = NewOpsHandler(svc)
h = NewOpsHandler(svc, nil)
r = newOpsSystemLogTestRouter(h, false)
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/logs/health", nil)

View File

@ -0,0 +1,198 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
type requestStreamWSMessage struct {
Type string `json:"type"`
Data service.RequestEvent `json:"data"`
}
// RequestStreamWSHandler streams real-time request events to WebSocket clients.
// GET /api/v1/admin/ops/ws/requests
//
// Each connected client receives a JSON message per gateway dispatch:
//
// {"type":"request_event","data":{"timestamp":...,"method":"POST","path":"/v1/messages",
// "model":"claude-3-5-sonnet-20241022","account_id":42,"status":"success","latency_ms":1230}}
func (h *OpsHandler) RequestStreamWSHandler(c *gin.Context) {
clientIP := requestClientIP(c.Request)
if h == nil || h.opsService == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "ops service not initialized"})
return
}
if h.requestEventBus == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "request event bus not initialized"})
return
}
if !h.opsService.IsRealtimeMonitoringEnabled(c.Request.Context()) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "ops realtime monitoring is disabled"})
return
}
closeWS(conn, opsWSCloseRealtimeDisabled, "realtime_disabled")
return
}
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer func() {
if wsConnCount.Add(-1) == 0 {
scheduleQPSWSIdleStop()
}
}()
if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" {
if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] per-ip limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP)
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"})
return
}
defer releaseOpsWSIPSlot(clientIP)
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] upgrade failed: %v", err)
return
}
defer func() { _ = conn.Close() }()
handleRequestStreamWebSocket(c.Request.Context(), conn, h.requestEventBus)
}
func handleRequestStreamWebSocket(parentCtx context.Context, conn *websocket.Conn, bus *service.RequestEventBus) {
if conn == nil || bus == nil {
return
}
ctx, cancel := context.WithCancel(parentCtx)
defer cancel()
subID, eventCh := bus.Subscribe()
defer bus.Unsubscribe(subID)
var closeOnce sync.Once
closeConn := func() {
closeOnce.Do(func() { _ = conn.Close() })
}
closeFrameCh := make(chan []byte, 1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
defer cancel()
conn.SetReadLimit(qpsWSMaxReadBytes)
if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] set read deadline failed: %v", err)
return
}
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(qpsWSPongWait))
})
conn.SetCloseHandler(func(code int, text string) error {
select {
case closeFrameCh <- websocket.FormatCloseMessage(code, text):
default:
}
cancel()
return nil
})
for {
_, _, err := conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] read failed: %v", err)
}
return
}
}
}()
pingTicker := time.NewTicker(qpsWSPingInterval)
defer pingTicker.Stop()
writeWithTimeout := func(messageType int, data []byte) error {
if err := conn.SetWriteDeadline(time.Now().Add(qpsWSWriteTimeout)); err != nil {
return err
}
return conn.WriteMessage(messageType, data)
}
sendClose := func(closeFrame []byte) {
if closeFrame == nil {
closeFrame = websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
}
_ = writeWithTimeout(websocket.CloseMessage, closeFrame)
}
for {
select {
case evt, ok := <-eventCh:
if !ok {
// channel closed by Unsubscribe
sendClose(nil)
closeConn()
wg.Wait()
return
}
msg, err := json.Marshal(requestStreamWSMessage{Type: "request_event", Data: evt})
if err != nil {
continue
}
if err := writeWithTimeout(websocket.TextMessage, msg); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] write failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case <-pingTicker.C:
if err := writeWithTimeout(websocket.PingMessage, nil); err != nil {
logger.LegacyPrintf("handler.admin.ops_ws_requests", "[OpsWSReq] ping failed: %v", err)
cancel()
closeConn()
wg.Wait()
return
}
case closeFrame := <-closeFrameCh:
sendClose(closeFrame)
closeConn()
wg.Wait()
return
case <-ctx.Done():
var closeFrame []byte
select {
case closeFrame = <-closeFrameCh:
default:
}
sendClose(closeFrame)
closeConn()
wg.Wait()
return
}
}
}

View File

@ -47,6 +47,7 @@ type GatewayHandler struct {
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
userMsgQueueHelper *UserMsgQueueHelper
requestEventBus *service.RequestEventBus
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
@ -68,6 +69,7 @@ func NewGatewayHandler(
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
settingService *service.SettingService,
requestEventBus *service.RequestEventBus,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
@ -100,6 +102,7 @@ func NewGatewayHandler(
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
userMsgQueueHelper: umqHelper,
requestEventBus: requestEventBus,
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
@ -110,6 +113,7 @@ func NewGatewayHandler(
// Messages handles Claude API compatible messages endpoint
// POST /v1/messages
func (h *GatewayHandler) Messages(c *gin.Context) {
reqStartTime := time.Now()
// 从context获取apiKey和userApiKeyAuth中间件已设置
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
@ -158,6 +162,25 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqStream := parsedReq.Stream
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
// 实时请求查看器:记录每次请求的结果(账号、模型、状态、延迟)
var (
reqEventAccountID int64
reqEventStatus = "error"
)
defer func() {
if h.requestEventBus != nil {
h.requestEventBus.Publish(service.RequestEvent{
Timestamp: reqStartTime,
Method: c.Request.Method,
Path: c.FullPath(),
Model: reqModel,
AccountID: reqEventAccountID,
Status: reqEventStatus,
LatencyMS: time.Since(reqStartTime).Milliseconds(),
})
}
}()
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
@ -393,6 +416,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
reqEventAccountID = account.ID
reqEventStatus = "rate_limited"
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
@ -458,6 +483,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 实时请求查看器:标记 Gemini 路径成功
reqEventAccountID = account.ID
reqEventStatus = "success"
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
@ -630,6 +659,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if accountReleaseFunc != nil {
accountReleaseFunc()
}
reqEventAccountID = account.ID
reqEventStatus = "rate_limited"
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "RPM rate limit exceeded, please retry later", streamStarted)
return
}
@ -805,6 +836,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 实时请求查看器:标记 Anthropic 路径成功
reqEventAccountID = account.ID
reqEventStatus = "success"
// RPM 计数递增Forward 成功后)
// 注意TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
// 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。

View File

@ -141,10 +141,11 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
settings.PUT("/metric-thresholds", h.Admin.Ops.UpdateMetricThresholds)
}
// WebSocket realtime (QPS/TPS)
// WebSocket realtime (QPS/TPS and request stream)
ws := ops.Group("/ws")
{
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
ws.GET("/requests", h.Admin.Ops.RequestStreamWSHandler)
}
// Error logs (legacy)

View File

@ -0,0 +1,75 @@
package service
import (
"sync"
"sync/atomic"
"time"
)
const requestEventBufSize = 64
// RequestEvent is published for every gateway dispatch completion.
type RequestEvent struct {
Timestamp time.Time `json:"timestamp"`
Method string `json:"method"`
Path string `json:"path"`
Model string `json:"model"`
AccountID int64 `json:"account_id"`
// Status is "success", "error", or "rate_limited".
Status string `json:"status"`
LatencyMS int64 `json:"latency_ms"`
}
// RequestEventBus is a fan-out hub for real-time request events.
// Publishers call Publish; subscribers call Subscribe/Unsubscribe.
// Each subscriber gets its own buffered channel. If the buffer is full
// the event is dropped for that subscriber (non-blocking publish).
type RequestEventBus struct {
mu sync.RWMutex
subscribers map[uint64]chan RequestEvent
nextID atomic.Uint64
}
func NewRequestEventBus() *RequestEventBus {
return &RequestEventBus{
subscribers: make(map[uint64]chan RequestEvent),
}
}
// Subscribe registers a new subscriber and returns its ID and a receive-only channel.
func (b *RequestEventBus) Subscribe() (uint64, <-chan RequestEvent) {
id := b.nextID.Add(1)
ch := make(chan RequestEvent, requestEventBufSize)
b.mu.Lock()
b.subscribers[id] = ch
b.mu.Unlock()
return id, ch
}
// Unsubscribe removes a subscriber and closes its channel.
func (b *RequestEventBus) Unsubscribe(id uint64) {
b.mu.Lock()
ch, ok := b.subscribers[id]
if ok {
delete(b.subscribers, id)
}
b.mu.Unlock()
if ok {
close(ch)
}
}
// Publish sends an event to all current subscribers without blocking.
func (b *RequestEventBus) Publish(e RequestEvent) {
if b == nil {
return
}
b.mu.RLock()
defer b.mu.RUnlock()
for _, ch := range b.subscribers {
select {
case ch <- e:
default:
}
}
}

View File

@ -0,0 +1,100 @@
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRequestEventBus_PublishToSubscriber(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
defer bus.Unsubscribe(id)
evt := RequestEvent{Model: "claude-3", Status: "success", LatencyMS: 100}
bus.Publish(evt)
select {
case got := <-ch:
assert.Equal(t, evt, got)
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
}
func TestRequestEventBus_MultipleSubscribers(t *testing.T) {
bus := NewRequestEventBus()
id1, ch1 := bus.Subscribe()
id2, ch2 := bus.Subscribe()
defer bus.Unsubscribe(id1)
defer bus.Unsubscribe(id2)
evt := RequestEvent{Model: "claude-3", Status: "error"}
bus.Publish(evt)
for _, ch := range []<-chan RequestEvent{ch1, ch2} {
select {
case got := <-ch:
assert.Equal(t, evt, got)
case <-time.After(time.Second):
t.Fatal("timed out waiting for event on one subscriber")
}
}
}
func TestRequestEventBus_UnsubscribeClosesChannel(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
bus.Unsubscribe(id)
// Channel should be closed.
_, ok := <-ch
assert.False(t, ok, "channel should be closed after Unsubscribe")
}
func TestRequestEventBus_UnsubscribedMissesEvents(t *testing.T) {
bus := NewRequestEventBus()
id, _ := bus.Subscribe()
bus.Unsubscribe(id)
// Publish after unsubscribe should not panic.
require.NotPanics(t, func() {
bus.Publish(RequestEvent{Model: "test"})
})
}
func TestRequestEventBus_DropWhenFull(t *testing.T) {
bus := NewRequestEventBus()
id, ch := bus.Subscribe()
defer bus.Unsubscribe(id)
// Fill the buffer then publish one more — should drop, not block.
evt := RequestEvent{Model: "model", Status: "success"}
for i := 0; i < requestEventBufSize; i++ {
bus.Publish(evt)
}
// This publish should return immediately (dropped).
done := make(chan struct{})
go func() {
bus.Publish(evt)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("Publish blocked when buffer was full")
}
assert.Len(t, ch, requestEventBufSize)
}
func TestRequestEventBus_NilSafePublish(t *testing.T) {
var bus *RequestEventBus
require.NotPanics(t, func() {
bus.Publish(RequestEvent{Model: "test"})
})
}

View File

@ -425,6 +425,7 @@ var ProviderSet = wire.NewSet(
NewAnnouncementService,
NewAdminService,
NewRPMTokenBucketService,
NewRequestEventBus,
NewGatewayService,
ProvideSoraMediaStorage,
ProvideSoraMediaCleanupService,