// Package lspool provides a mock Extension Server that the LS binary connects // to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server // using connectNodeAdapter. We replicate that protocol here. // // Protocol details (from extension.js source): // - Transport: HTTP/1.1 on 127.0.0.1 (no TLS) // - Auth: x-codeium-csrf-token header on every request // - Unary request Content-Type: application/proto (binary protobuf, no envelope) // OR application/connect+proto (with 5-byte envelope) // - Unary response Content-Type: application/proto (raw binary protobuf, no envelope) // - Stream request Content-Type: application/connect+proto (with 5-byte envelope) // - Stream response Content-Type: application/connect+proto (envelope-framed messages) // // The LS sends requests with content-type "application/connect+proto" for BOTH // unary and streaming RPCs. ConnectRPC's content-type regex: // // /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i // // If "connect+" prefix is present → stream mode; otherwise → unary mode. // However the LS Go client uses the Connect protocol client which always sends // "application/proto" for unary and "application/connect+proto" for streaming. // // We detect the RPC kind from the URL path and respond accordingly. package lspool import ( "encoding/base64" "encoding/binary" "fmt" "io" "log/slog" "net" "net/http" "strings" "sync" "time" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) // ============================================================ // Proto helpers — hand-encode minimal proto messages so we don't // need to import the full generated proto package. // ============================================================ // encodeProtoString writes a proto string field (wire type 2) to a byte slice. 
func encodeProtoString(fieldNum int, val string) []byte { tag := encodeVarint(uint64(fieldNum<<3 | 2)) length := encodeVarint(uint64(len(val))) out := make([]byte, 0, len(tag)+len(length)+len(val)) out = append(out, tag...) out = append(out, length...) out = append(out, []byte(val)...) return out } // encodeProtoBytes writes a proto bytes/message field (wire type 2). func encodeProtoBytes(fieldNum int, val []byte) []byte { tag := encodeVarint(uint64(fieldNum<<3 | 2)) length := encodeVarint(uint64(len(val))) out := make([]byte, 0, len(tag)+len(length)+len(val)) out = append(out, tag...) out = append(out, length...) out = append(out, val...) return out } // encodeProtoVarint writes a proto varint field (wire type 0). func encodeProtoVarint(fieldNum int, val uint64) []byte { tag := encodeVarint(uint64(fieldNum<<3 | 0)) v := encodeVarint(val) out := make([]byte, 0, len(tag)+len(v)) out = append(out, tag...) out = append(out, v...) return out } // encodeProtoBool writes a proto bool field. func encodeProtoBool(fieldNum int, val bool) []byte { v := uint64(0) if val { v = 1 } return encodeProtoVarint(fieldNum, v) } func encodeVarint(v uint64) []byte { buf := make([]byte, binary.MaxVarintLen64) n := binary.PutUvarint(buf, v) return buf[:n] } // decodeProtoString extracts a string field from raw proto bytes. 
//
// It scans the top-level fields of data and returns the payload of the first
// length-delimited (wire type 2) field whose number equals targetField.
// Varint, 64-bit and 32-bit fields are skipped. It returns "" when the field
// is absent, when data is truncated or malformed, or when an unsupported wire
// type (such as groups) is encountered.
//
// Every declared length is bounds-checked in uint64 space before it is used.
// Converting a hostile 64-bit length straight to int can wrap negative, which
// would either panic on slicing or move the cursor backwards (looping forever).
func decodeProtoString(data []byte, targetField int) string {
	for i := 0; i < len(data); {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			break
		}
		i += n
		fieldNum := int(tag >> 3)

		switch wireType := tag & 0x7; wireType {
		case 0: // varint
			if _, n = binary.Uvarint(data[i:]); n <= 0 {
				return ""
			}
			i += n
		case 1: // 64-bit; overrunning len(data) simply ends the loop
			i += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
			if length > uint64(len(data)-i) {
				return "" // declared length runs past the end of data
			}
			end := i + int(length) // safe: length <= len(data)-i
			if fieldNum == targetField {
				return string(data[i:end])
			}
			i = end
		case 5: // 32-bit
			i += 4
		default:
			return ""
		}
	}
	return ""
}

// ============================================================
// ConnectRPC envelope helpers
// ============================================================

// connectEnvelope wraps a proto payload in a ConnectRPC streaming envelope:
// 1 byte flags + 4 byte big-endian length + payload
func connectEnvelope(flags byte, payload []byte) []byte {
	frame := make([]byte, 5+len(payload))
	frame[0] = flags
	binary.BigEndian.PutUint32(frame[1:5], uint32(len(payload)))
	copy(frame[5:], payload)
	return frame
}

// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
func connectEndOfStream() []byte {
	trailer := []byte("{}")
	return connectEnvelope(0x02, trailer)
}

// unwrapConnectEnvelope strips the 5-byte envelope header from a ConnectRPC message.
// Returns the raw proto payload. If the input is shorter than 5 bytes, returns as-is.
func unwrapConnectEnvelope(body []byte) []byte { if len(body) < 5 { return body } // Check if it looks like an envelope: first byte should be 0x00 or 0x01 if body[0] > 0x02 { return body // Not envelope-framed, return raw } plen := binary.BigEndian.Uint32(body[1:5]) if int(plen)+5 > len(body) { return body // Length mismatch, return raw } return body[5 : 5+plen] } // ============================================================ // OAuthTokenInfo proto builder // ============================================================ // buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto. // // message OAuthTokenInfo { // string access_token = 1; // string token_type = 2; // string refresh_token = 3; // google.protobuf.Timestamp expiry = 4; // bool is_gcp_tos = 6; // } func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte { var buf []byte buf = append(buf, encodeProtoString(1, accessToken)...) buf = append(buf, encodeProtoString(2, "Bearer")...) if refreshToken != "" { buf = append(buf, encodeProtoString(3, refreshToken)...) } // Use real expiry if provided, otherwise default to 1 hour from now expiry := expiresAt if expiry.IsZero() { expiry = time.Now().Add(1 * time.Hour) } ts := ×tamppb.Timestamp{ Seconds: expiry.Unix(), } tsBytes, _ := proto.Marshal(ts) buf = append(buf, encodeProtoBytes(4, tsBytes)...) buf = append(buf, encodeProtoBool(6, true)...) return buf } // buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token. // // message Topic { map data = 1; } // message Row { string value = 1; int64 e_tag = 2; } // // The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is // base64(toBinary(OAuthTokenInfo)). 
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte { tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt) tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) // Row: value=tokenB64 (field 1), e_tag=1 (field 2) var row []byte row = append(row, encodeProtoString(1, tokenB64)...) row = append(row, encodeProtoVarint(2, 1)...) // Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2) var entry []byte entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...) entry = append(entry, encodeProtoBytes(2, row)...) // Topic: data map entries use field 1 var topic []byte topic = append(topic, encodeProtoBytes(1, entry)...) return topic } func buildPrimitiveBoolBinary(val bool) []byte { // Primitive.bool_value is field 13 in the proto definition return encodeProtoBool(13, val) } func buildPrimitiveInt32Binary(val int32) []byte { // Primitive.int32_value is field 3 in the proto definition return encodeProtoVarint(3, uint64(uint32(val))) } func encodeUSSBinaryValue(value []byte) string { return base64.StdEncoding.EncodeToString(value) } func encodeUSSPrimitiveBoolValue(val bool) string { return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val)) } func encodeUSSPrimitiveInt32Value(val int32) string { return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val)) } func buildUSSTopicRow(key string, value string) []byte { row := buildUSSRowBinary(value) var entry []byte entry = append(entry, encodeProtoString(1, key)...) entry = append(entry, encodeProtoBytes(2, row)...) return entry } func buildUSSRowBinary(value string) []byte { var row []byte row = append(row, encodeProtoString(1, value)...) row = append(row, encodeProtoVarint(2, 1)...) 
return row } func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte { if info == nil { info = &ModelCreditsInfo{} } minimum := defaultMinimumCreditAmountForUsage if info.MinimumCreditAmountForUsage != nil { minimum = *info.MinimumCreditAmountForUsage } entries := make([][]byte, 0, 3) entries = append(entries, buildUSSTopicRow( useAICreditsSentinelKey, encodeUSSPrimitiveBoolValue(info.UseAICredits), )) // JS protocol: useAICreditsSentinelKey carries the toggle state. // availableCreditsSentinelKey is only present when credits are enabled. if info.UseAICredits { credits := int32(9999) if info.AvailableCredits != nil { credits = *info.AvailableCredits } entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits))) } entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum))) var topic []byte for _, entry := range entries { topic = append(topic, encodeProtoBytes(1, entry)...) } return topic } // buildEmptyTopic returns an empty USS Topic proto (for non-oauth topics). func buildEmptyTopic() []byte { return []byte{} // Empty message = no map entries } // ============================================================ // UnifiedStateSyncUpdate builder // ============================================================ // buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set. // // message UnifiedStateSyncUpdate { // oneof update_type { // Topic initial_state = 1; // AppliedUpdate applied_update = 2; // } // } func buildInitialStateUpdate(topicData []byte) []byte { return encodeProtoBytes(1, topicData) } func buildAppliedUpdate(key string, row []byte) []byte { var applied []byte applied = append(applied, encodeProtoString(1, key)...) if len(row) > 0 { applied = append(applied, encodeProtoBytes(2, row)...) 
} return encodeProtoBytes(2, applied) } // ============================================================ // MockExtensionServer // ============================================================ // MockExtensionServer provides a ConnectRPC-compatible HTTP server that the // Language Server binary connects to. It implements just enough of the // ExtensionServerService to keep the LS operational. type MockExtensionServer struct { listener net.Listener server *http.Server port int csrf string mu sync.RWMutex tokens map[string]*TokenInfo // account_id -> token info credits map[string]*ModelCreditsInfo // account_id -> model credits info subscribers map[string]map[int]*stateSubscriber nextSubID int lastAccountID string logger *slog.Logger // Trajectory callback — when LS pushes trajectory updates, we forward them onTrajectoryUpdate func(topic, key string, data []byte) } // TokenInfo holds OAuth token details for an account. type TokenInfo struct { AccessToken string RefreshToken string ExpiresAt time.Time // zero value means unknown; defaults to now+1h } // ModelCreditsInfo mirrors the JS uss-modelCredits topic state. type ModelCreditsInfo struct { UseAICredits bool AvailableCredits *int32 MinimumCreditAmountForUsage *int32 } type stateSubscriber struct { id int accountID string topic string updates chan []byte } const ( useAICreditsSentinelKey = "useAICreditsSentinelKey" availableCreditsSentinelKey = "availableCreditsSentinelKey" minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey" defaultMinimumCreditAmountForUsage = int32(50) ) // NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling. 
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("listen: %w", err) } m := &MockExtensionServer{ listener: listener, port: listener.Addr().(*net.TCPAddr).Port, csrf: csrf, tokens: make(map[string]*TokenInfo), credits: make(map[string]*ModelCreditsInfo), subscribers: make(map[string]map[int]*stateSubscriber), logger: slog.Default().With("component", "mock-ext-server"), } mux := http.NewServeMux() extService := "/exa.extension_server_pb.ExtensionServerService/" // Register all RPCs the LS calls on the Extension Server. // Unary RPCs — return application/proto mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted)) mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat)) mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue)) mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue)) mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled)) mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate)) mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError)) mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent)) mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries)) mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault)) 
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault)) mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault)) // Server-streaming RPCs — return application/connect+proto mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic)) mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand)) // Catch-all for any unregistered RPCs mux.HandleFunc("/", m.handleCatchAll) m.server = &http.Server{Handler: mux} go func() { if err := m.server.Serve(listener); err != http.ErrServerClosed { m.logger.Error("extension server error", "err", err) } }() m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf)) return m, nil } // Port returns the listening port. func (m *MockExtensionServer) Port() int { return m.port } // SetToken sets the OAuth token for an account. func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) { m.mu.Lock() m.tokens[accountID] = info m.lastAccountID = accountID subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID) m.mu.Unlock() if info == nil { return } tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt) tokenB64 := base64.StdEncoding.EncodeToString(tokenBin) m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64))) } // SetModelCredits sets the uss-modelCredits state for an account. 
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) { if info == nil { info = &ModelCreditsInfo{} } copyInfo := *info m.mu.Lock() m.credits[accountID] = ©Info m.lastAccountID = accountID subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID) m.mu.Unlock() m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...) } // SetTrajectoryCallback registers a callback for when the LS pushes trajectory data. func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) { m.onTrajectoryUpdate = fn } func (m *MockExtensionServer) currentTokenLocked() *TokenInfo { if m.lastAccountID != "" { if info := m.tokens[m.lastAccountID]; info != nil { return info } } for _, info := range m.tokens { return info } return nil } func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo { if m.lastAccountID != "" { if info := m.credits[m.lastAccountID]; info != nil { return info } } for _, info := range m.credits { return info } return nil } func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo { if accountID != "" { if info := m.tokens[accountID]; info != nil { return info } } return m.currentTokenLocked() } func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo { if accountID != "" { if info := m.credits[accountID]; info != nil { return info } } return m.currentModelCreditsLocked() } func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber { topicSubs := m.subscribers[topic] if len(topicSubs) == 0 { return nil } out := make([]*stateSubscriber, 0, len(topicSubs)) for _, sub := range topicSubs { if sub == nil { continue } if accountID != "" && sub.accountID != "" && sub.accountID != accountID { continue } out = append(out, sub) } return out } func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) { for _, sub := range 
subscribers { if sub == nil { continue } for _, update := range updates { if len(update) == 0 { continue } payload := append([]byte(nil), update...) select { case sub.updates <- payload: default: m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID) } } } } func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte { if info == nil { info = &ModelCreditsInfo{} } minimum := defaultMinimumCreditAmountForUsage if info.MinimumCreditAmountForUsage != nil { minimum = *info.MinimumCreditAmountForUsage } updates := make([][]byte, 0, 3) updates = append(updates, buildAppliedUpdate( useAICreditsSentinelKey, buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)), )) if info.UseAICredits { credits := int32(9999) if info.AvailableCredits != nil { credits = *info.AvailableCredits } updates = append(updates, buildAppliedUpdate( availableCreditsSentinelKey, buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)), )) } else { updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil)) } updates = append(updates, buildAppliedUpdate( minimumCreditAmountForUsageKey, buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)), )) return updates } // Close shuts down the server. func (m *MockExtensionServer) Close() error { return m.server.Close() } // ============================================================ // Middleware // ============================================================ type unaryHandler func(body []byte) []byte type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request) // handleUnary wraps a unary RPC handler with CSRF check and proper content-type. 
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // CSRF check if !m.checkCSRF(w, r) { return } body, err := io.ReadAll(r.Body) if err != nil { m.logger.Error("read body", "err", err, "path", r.URL.Path) w.Header().Set("Content-Type", "application/proto") w.WriteHeader(200) return } // The LS might send with envelope framing (application/connect+proto) // or without (application/proto). Detect and unwrap. ct := r.Header.Get("Content-Type") protoBody := body if strings.Contains(ct, "connect+proto") && len(body) >= 5 { protoBody = unwrapConnectEnvelope(body) } m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct) responseProto := handler(protoBody) // Respond with proper unary ConnectRPC content-type. // If the request used "connect+proto", the response should be "application/proto" // for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto). w.Header().Set("Content-Type", "application/proto") w.WriteHeader(200) if len(responseProto) > 0 { w.Write(responseProto) } } } // handleStream wraps a server-streaming RPC handler with CSRF and content-type. 
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !m.checkCSRF(w, r) { return } body, err := io.ReadAll(r.Body) if err != nil { m.logger.Error("read body", "err", err, "path", r.URL.Path) return } // Unwrap envelope from request ct := r.Header.Get("Content-Type") if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") { body = unwrapConnectEnvelope(body) } m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body)) // Set streaming response content-type w.Header().Set("Content-Type", "application/connect+proto") w.WriteHeader(200) handler(body, w, r) } } func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool { token := r.Header.Get("x-codeium-csrf-token") if m.csrf != "" && token != m.csrf { m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))]) w.Header().Set("Content-Type", "text/plain") w.WriteHeader(403) w.Write([]byte("Invalid CSRF token")) return false } return true } func min(a, b int) int { if a < b { return a } return b } // ============================================================ // Unary RPC Handlers — each receives raw proto request body, // returns raw proto response body. // ============================================================ func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte { // LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4) // We just log the ports — they're informational. 
m.logger.Info("LanguageServerStarted", "body_len", len(body)) // Return empty LanguageServerStartedResponse return nil } func (m *MockExtensionServer) onHeartbeat(body []byte) []byte { // Return empty HeartbeatResponse return nil } func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte { // GetSecretValueRequest: key = field 1 key := decodeProtoString(body, 1) m.logger.Debug("GetSecretValue", "key", key) m.mu.RLock() var token string if info := m.currentTokenLocked(); info != nil { token = info.AccessToken } m.mu.RUnlock() // GetSecretValueResponse: value = field 1 if token != "" { return encodeProtoString(1, token) } return nil } func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte { key := decodeProtoString(body, 1) m.logger.Debug("StoreSecretValue", "key", key) return nil } func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte { // IsAgentManagerEnabledResponse: enabled = field 1 (bool) return encodeProtoBool(1, false) } func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte { // PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message) // UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2 m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body)) // Extract topic name from the embedded UpdateRequest // The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest // We need to dig into the nested message to get topic_name if m.onTrajectoryUpdate != nil { // For now, just notify that an update was pushed m.onTrajectoryUpdate("", "", body) } // Return empty PushUnifiedStateSyncUpdateResponse return nil } func (m *MockExtensionServer) onRecordError(body []byte) []byte { m.logger.Debug("RecordError", "body_len", len(body)) return nil } func (m *MockExtensionServer) onLogEvent(body []byte) []byte { return nil } func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte { 
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body)) return nil } func (m *MockExtensionServer) onDefault(body []byte) []byte { return nil } // ============================================================ // Streaming RPC Handlers // ============================================================ func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) { // SubscribeToUnifiedStateSyncTopicRequest: topic = field 1 topic := decodeProtoString(body, 1) m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic) flusher, ok := w.(http.Flusher) if !ok { m.logger.Error("ResponseWriter does not support Flush") return } m.mu.Lock() accountID := m.lastAccountID subID := m.nextSubID m.nextSubID++ sub := &stateSubscriber{ id: subID, accountID: accountID, topic: topic, updates: make(chan []byte, 16), } if m.subscribers[topic] == nil { m.subscribers[topic] = make(map[int]*stateSubscriber) } m.subscribers[topic][subID] = sub // Build initial state based on topic var topicData []byte switch topic { case "uss-oauth": tokenInfo := m.tokenForAccountLocked(accountID) if tokenInfo != nil { topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt) } else { topicData = buildEmptyTopic() } case "uss-modelCredits": creditsInfo := m.creditsForAccountLocked(accountID) if creditsInfo != nil { topicData = buildUSSTopicWithModelCredits(creditsInfo) } else { topicData = buildEmptyTopic() } default: // For all other topics (browserPreferences, enterprisePreferences, etc.), // return empty topic data. 
topicData = buildEmptyTopic() } m.mu.Unlock() defer func() { m.mu.Lock() if topicSubs := m.subscribers[topic]; topicSubs != nil { delete(topicSubs, subID) if len(topicSubs) == 0 { delete(m.subscribers, topic) } } m.mu.Unlock() }() // Send initial state as envelope-framed message initialUpdate := buildInitialStateUpdate(topicData) frame := connectEnvelope(0x00, initialUpdate) w.Write(frame) flusher.Flush() for { select { case <-r.Context().Done(): m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic) return case update := <-sub.updates: if len(update) == 0 { continue } if _, err := w.Write(connectEnvelope(0x00, update)); err != nil { m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err) return } flusher.Flush() } } } func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) { m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body)) // Send end-of-stream immediately — we don't execute commands flusher, ok := w.(http.Flusher) if !ok { return } w.Write(connectEndOfStream()) flusher.Flush() } // ============================================================ // Catch-all handler // ============================================================ func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) { if !m.checkCSRF(w, r) { return } m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method) // Drain request body io.ReadAll(r.Body) // Determine if this is likely a unary or streaming request based on content-type. ct := r.Header.Get("Content-Type") if strings.Contains(ct, "connect+") { // Could be streaming — respond with unary proto to be safe // (unary Connect requests can also use connect+ prefix in some client impls) w.Header().Set("Content-Type", "application/proto") } else { w.Header().Set("Content-Type", "application/proto") } w.WriteHeader(200) }