sub2api/backend/internal/pkg/lspool/mock_extension_server.go
win 0cda0e0b96
Some checks failed
CI / test (push) Failing after 8s
CI / golangci-lint (push) Failing after 5s
Security Scan / backend-security (push) Failing after 7s
Security Scan / frontend-security (push) Failing after 6s
feat: add dockerized antigravity ls worker mode
2026-03-30 23:57:25 +08:00

909 lines
28 KiB
Go

// Package lspool provides a mock Extension Server that the LS binary connects
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
// using connectNodeAdapter. We replicate that protocol here.
//
// Protocol details (from extension.js source):
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
// - Auth: x-codeium-csrf-token header on every request
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
// OR application/connect+proto (with 5-byte envelope)
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
//
// In practice the LS may send "application/connect+proto" for unary RPCs as
// well as for streaming ones, even though a standard Connect protocol client
// sends "application/proto" for unary calls and "application/connect+proto"
// only for streaming calls. ConnectRPC's content-type regex is:
//
// /^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
//
// A "connect+" prefix selects stream mode; without it, unary mode is used.
// Because the request content-type is therefore not a reliable indicator, we
// decide unary vs. streaming from the URL path (each RPC is registered as
// either unary or streaming) and accept both framings on unary endpoints.
package lspool
import (
	"crypto/subtle"
	"encoding/base64"
	"encoding/binary"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/timestamppb"
)
// ============================================================
// Proto helpers — hand-encode minimal proto messages so we don't
// need to import the full generated proto package.
// ============================================================
// encodeProtoString encodes one length-delimited (wire type 2) protobuf field
// holding the bytes of val: the tag varint, the byte-length varint, then val.
func encodeProtoString(fieldNum int, val string) []byte {
	// encodeVarint returns a freshly allocated slice, so appending to it
	// cannot clobber shared memory.
	header := append(encodeVarint(uint64(fieldNum<<3|2)), encodeVarint(uint64(len(val)))...)
	return append(header, val...)
}
// encodeProtoBytes encodes one length-delimited (wire type 2) field carrying
// arbitrary bytes: either a bytes field or an already-serialized nested
// message. A nil or empty val yields the tag followed by a zero length.
func encodeProtoBytes(fieldNum int, val []byte) []byte {
	header := append(encodeVarint(uint64(fieldNum<<3|2)), encodeVarint(uint64(len(val)))...)
	return append(header, val...)
}
// encodeProtoVarint encodes one varint (wire type 0) field: the tag varint
// (field number shifted left by 3, low bits 0) followed by val as a varint.
func encodeProtoVarint(fieldNum int, val uint64) []byte {
	return append(encodeVarint(uint64(fieldNum<<3)), encodeVarint(val)...)
}
// encodeProtoBool encodes a bool field as a varint holding 1 (true) or 0
// (false). A false value is still written explicitly rather than omitted.
func encodeProtoBool(fieldNum int, val bool) []byte {
	if val {
		return encodeProtoVarint(fieldNum, 1)
	}
	return encodeProtoVarint(fieldNum, 0)
}
// encodeVarint returns v as an unsigned base-128 varint, the integer encoding
// protobuf uses for tags, lengths and varint fields. The result is a new
// slice on every call.
func encodeVarint(v uint64) []byte {
	return binary.AppendUvarint(make([]byte, 0, binary.MaxVarintLen64), v)
}
// decodeProtoString scans raw protobuf bytes and returns the first
// length-delimited (wire type 2) field numbered targetField, as a string.
// It returns "" when the field is absent, when the input is malformed or
// truncated, or when an unsupported wire type (e.g. groups, 3/4) is reached
// before the field.
//
// data comes straight from HTTP request bodies, so every length read from the
// wire is bounds-checked before use. A huge length varint must not be
// converted to a negative int, because that would panic on slicing.
func decodeProtoString(data []byte, targetField int) string {
	i := 0
	for i < len(data) {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			break
		}
		i += n
		fieldNum := int(tag >> 3)
		switch tag & 0x7 {
		case 0: // varint
			if _, n = binary.Uvarint(data[i:]); n <= 0 {
				return ""
			}
			i += n
		case 1: // 64-bit
			if len(data)-i < 8 {
				return ""
			}
			i += 8
		case 2: // length-delimited
			length, k := binary.Uvarint(data[i:])
			if k <= 0 {
				return ""
			}
			i += k
			// Compare in uint64 so an attacker-controlled length cannot
			// overflow int.
			if length > uint64(len(data)-i) {
				return ""
			}
			end := i + int(length)
			if fieldNum == targetField {
				return string(data[i:end])
			}
			i = end
		case 5: // 32-bit
			if len(data)-i < 4 {
				return ""
			}
			i += 4
		default:
			return ""
		}
	}
	return ""
}
// ============================================================
// ConnectRPC envelope helpers
// ============================================================
// connectEnvelope frames payload as one ConnectRPC streaming message: a
// flags byte, a 4-byte big-endian payload length, then the payload bytes.
func connectEnvelope(flags byte, payload []byte) []byte {
	var header [5]byte
	header[0] = flags
	binary.BigEndian.PutUint32(header[1:], uint32(len(payload)))
	out := make([]byte, 0, len(header)+len(payload))
	out = append(out, header[:]...)
	return append(out, payload...)
}
// connectEndOfStream returns the ConnectRPC end-of-stream frame: flags 0x02
// and a JSON payload of "{}", i.e. an end-of-stream message carrying no
// metadata and no error.
func connectEndOfStream() []byte {
	const endStreamFlag byte = 0x02
	return connectEnvelope(endStreamFlag, []byte("{}"))
}
// unwrapConnectEnvelope strips a 5-byte ConnectRPC envelope header (flags
// byte + 4-byte big-endian length) from body and returns the payload.
//
// Framing is detected heuristically. body is returned unchanged when it does
// not look like an envelope: when it is shorter than 5 bytes, when the flags
// byte is above 0x02, or when the declared length exceeds the bytes
// available. Only the first enveloped message is returned; any trailing bytes
// are ignored. The flags byte is not otherwise interpreted.
func unwrapConnectEnvelope(body []byte) []byte {
	const headerLen = 5
	if len(body) < headerLen {
		return body
	}
	// Accept flags 0x00..0x02. An unenveloped proto message cannot start with
	// these bytes, because they would encode a tag for field number 0, which
	// is invalid.
	if body[0] > 0x02 {
		return body // Not envelope-framed, return raw
	}
	plen := binary.BigEndian.Uint32(body[1:headerLen])
	// Compare in uint64 so the check cannot overflow int on 32-bit platforms.
	if uint64(plen) > uint64(len(body)-headerLen) {
		return body // Length mismatch, return raw
	}
	return body[headerLen : headerLen+int(plen)]
}
// ============================================================
// OAuthTokenInfo proto builder
// ============================================================
// buildOAuthTokenInfoBinary hand-encodes an OAuthTokenInfo message:
//
//	message OAuthTokenInfo {
//	  string access_token = 1;
//	  string token_type = 2;
//	  string refresh_token = 3;
//	  google.protobuf.Timestamp expiry = 4;
//	  bool is_gcp_tos = 6;
//	}
//
// token_type is always "Bearer" and is_gcp_tos is always true. refresh_token
// is omitted when empty. A zero expiresAt is replaced with now+1h, and the
// expiry is sent with whole-second precision (nanoseconds are dropped).
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
	msg := encodeProtoString(1, accessToken)
	msg = append(msg, encodeProtoString(2, "Bearer")...)
	if refreshToken != "" {
		msg = append(msg, encodeProtoString(3, refreshToken)...)
	}

	if expiresAt.IsZero() {
		expiresAt = time.Now().Add(1 * time.Hour)
	}
	// A Timestamp that only sets Seconds has no required fields or string
	// data, so Marshal is not expected to fail; the error is ignored.
	expiryBytes, _ := proto.Marshal(&timestamppb.Timestamp{Seconds: expiresAt.Unix()})
	msg = append(msg, encodeProtoBytes(4, expiryBytes)...)

	return append(msg, encodeProtoBool(6, true)...)
}
// buildUSSTopicWithOAuth builds a USS Topic holding a single map entry with
// the OAuth token:
//
//	message Topic { map<string, Row> data = 1; }
//	message Row   { string value = 1; int64 e_tag = 2; }
//
// The entry key is "oauthTokenInfoSentinelKey". Row.value is the standard
// base64 encoding of the binary OAuthTokenInfo, and e_tag is always 1.
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
	tokenB64 := base64.StdEncoding.EncodeToString(
		buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt))
	// Map entry: key in field 1, Row{value, e_tag=1} in field 2.
	entry := buildUSSTopicRow("oauthTokenInfoSentinelKey", tokenB64)
	// Topic.data map entries are repeated occurrences of field 1.
	return encodeProtoBytes(1, entry)
}
// buildPrimitiveBoolBinary encodes a Primitive message whose bool_value
// (field 13) is val. false is written explicitly rather than omitted.
func buildPrimitiveBoolBinary(val bool) []byte {
	const boolValueField = 13 // Primitive.bool_value
	return encodeProtoBool(boolValueField, val)
}
// buildPrimitiveInt32Binary encodes a Primitive message whose int32_value
// (field 3, varint) is val.
//
// Negative values are sign-extended to 64 bits before varint encoding, as the
// protobuf encoding rules require for int32 fields. A negative int32 always
// occupies 10 bytes on the wire.
func buildPrimitiveInt32Binary(val int32) []byte {
	const int32ValueField = 3 // Primitive.int32_value
	tag := binary.AppendUvarint(nil, int32ValueField<<3) // wire type 0 (varint)
	return binary.AppendUvarint(tag, uint64(int64(val)))
}
func buildUSSTopicRow(key string, value string) []byte {
row := buildUSSRowBinary(value)
var entry []byte
entry = append(entry, encodeProtoString(1, key)...)
entry = append(entry, encodeProtoBytes(2, row)...)
return entry
}
// buildUSSRowBinary encodes a USS Row message: value in field 1 and a
// constant e_tag of 1 in field 2.
func buildUSSRowBinary(value string) []byte {
	const eTag = 1
	return append(encodeProtoString(1, value), encodeProtoVarint(2, eTag)...)
}
// buildUSSTopicWithModelCredits builds the initial uss-modelCredits Topic
// from info; nil is treated as the zero value. Each Row.value holds a binary
// Primitive message, and entries are emitted in this order:
//   - useAICreditsSentinelKey: bool, always present;
//   - availableCreditsSentinelKey: int32, only when UseAICredits is true
//     (AvailableCredits, or 9999 when unset);
//   - minimumCreditAmountForUsageKey: int32, always present
//     (MinimumCreditAmountForUsage, or defaultMinimumCreditAmountForUsage).
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
	var cfg ModelCreditsInfo
	if info != nil {
		cfg = *info
	}
	minimum := defaultMinimumCreditAmountForUsage
	if cfg.MinimumCreditAmountForUsage != nil {
		minimum = *cfg.MinimumCreditAmountForUsage
	}

	var topic []byte
	addEntry := func(key string, primitive []byte) {
		// Topic.data map entries are repeated occurrences of field 1.
		topic = append(topic, encodeProtoBytes(1, buildUSSTopicRow(key, string(primitive)))...)
	}

	addEntry(useAICreditsSentinelKey, buildPrimitiveBoolBinary(cfg.UseAICredits))
	// Mirrors the JS protocol: the available-credits row is only present
	// while credits are enabled.
	if cfg.UseAICredits {
		available := int32(9999)
		if cfg.AvailableCredits != nil {
			available = *cfg.AvailableCredits
		}
		addEntry(availableCreditsSentinelKey, buildPrimitiveInt32Binary(available))
	}
	addEntry(minimumCreditAmountForUsageKey, buildPrimitiveInt32Binary(minimum))
	return topic
}
// buildEmptyTopic returns the encoding of a USS Topic with no map entries.
// The result is an empty but non-nil byte slice. It is used for topics this
// mock holds no state for.
func buildEmptyTopic() []byte {
	return make([]byte, 0)
}
// ============================================================
// UnifiedStateSyncUpdate builder
// ============================================================
// buildInitialStateUpdate wraps an already-serialized Topic as the
// initial_state arm (field 1) of a UnifiedStateSyncUpdate:
//
//	message UnifiedStateSyncUpdate {
//	  oneof update_type {
//	    Topic initial_state = 1;
//	    AppliedUpdate applied_update = 2;
//	  }
//	}
func buildInitialStateUpdate(topicData []byte) []byte {
	const initialStateField = 1
	return encodeProtoBytes(initialStateField, topicData)
}
// buildAppliedUpdate encodes a UnifiedStateSyncUpdate whose applied_update arm
// (field 2) carries an AppliedUpdate: key in field 1 and, only when row is
// non-empty, the serialized Row in field 2.
// NOTE(review): callers pass a nil row to send a key with no Row; presumably
// the LS treats this as deleting the key. Confirm against the LS.
func buildAppliedUpdate(key string, row []byte) []byte {
	applied := encodeProtoString(1, key)
	if len(row) != 0 {
		applied = append(applied, encodeProtoBytes(2, row)...)
	}
	return encodeProtoBytes(2, applied)
}
// ============================================================
// MockExtensionServer
// ============================================================
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
// Language Server binary connects to. It implements just enough of the
// ExtensionServerService to keep the LS operational.
//
// Token and model-credit state is kept per account ID. Each streaming
// SubscribeToUnifiedStateSyncTopic call registers a stateSubscriber, so that
// later SetToken / SetModelCredits calls are pushed to the LS as applied
// updates.
type MockExtensionServer struct {
	listener net.Listener
	server   *http.Server
	port     int    // TCP port the OS assigned to the 127.0.0.1:0 listener
	csrf     string // expected x-codeium-csrf-token value; "" disables the check

	// mu guards tokens, credits, subscribers, nextSubID and lastAccountID.
	mu          sync.RWMutex
	tokens      map[string]*TokenInfo               // account_id -> token info
	credits     map[string]*ModelCreditsInfo        // account_id -> model credits info
	subscribers map[string]map[int]*stateSubscriber // topic -> subscriber id -> subscriber
	nextSubID   int                                 // next id handed to a new subscriber
	// lastAccountID is the account most recently passed to SetToken or
	// SetModelCredits. New subscriptions are bound to it, and it selects
	// the token returned by GetSecretValue.
	lastAccountID string
	logger        *slog.Logger
	// onTrajectoryUpdate is invoked for every PushUnifiedStateSyncUpdate
	// request with the raw request body; topic and key are currently passed
	// as empty strings. nil means no callback.
	onTrajectoryUpdate func(topic, key string, data []byte)
}
// TokenInfo holds OAuth token details for an account. It is encoded into the
// OAuthTokenInfo message published on the uss-oauth topic.
type TokenInfo struct {
	AccessToken  string
	RefreshToken string    // omitted from the encoded OAuthTokenInfo when empty
	ExpiresAt    time.Time // zero value means unknown; defaults to now+1h
}
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
type ModelCreditsInfo struct {
	// UseAICredits is published under useAICreditsSentinelKey.
	UseAICredits bool
	// AvailableCredits is published under availableCreditsSentinelKey, but
	// only while UseAICredits is true; nil means 9999.
	AvailableCredits *int32
	// MinimumCreditAmountForUsage is published under
	// minimumCreditAmountForUsageKey; nil means
	// defaultMinimumCreditAmountForUsage.
	MinimumCreditAmountForUsage *int32
}
// stateSubscriber represents one active SubscribeToUnifiedStateSyncTopic
// stream.
type stateSubscriber struct {
	id        int    // key under MockExtensionServer.subscribers[topic]
	accountID string // account bound at subscribe time; "" matches any account
	topic     string
	// updates carries serialized UnifiedStateSyncUpdate messages to the
	// stream handler. Publishers drop an update instead of blocking when the
	// buffer is full. The subscription handler does not close it.
	updates chan []byte
}
// Row keys of the uss-modelCredits topic, plus the minimum credit amount
// reported when ModelCreditsInfo.MinimumCreditAmountForUsage is nil.
const (
	useAICreditsSentinelKey            = "useAICreditsSentinelKey"
	availableCreditsSentinelKey        = "availableCreditsSentinelKey"
	minimumCreditAmountForUsageKey     = "minimumCreditAmountForUsageKey"
	defaultMinimumCreditAmountForUsage = int32(50)
)
// NewMockExtensionServer starts a mock Extension Server listening on an
// OS-assigned port on 127.0.0.1 and returns it. Every route requires csrf in
// the x-codeium-csrf-token header; an empty csrf disables the check. The HTTP
// server runs in a background goroutine until Close is called. The only error
// returned is a failure to open the listener.
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("listen: %w", err)
	}
	m := &MockExtensionServer{
		listener:    listener,
		port:        listener.Addr().(*net.TCPAddr).Port,
		csrf:        csrf,
		tokens:      make(map[string]*TokenInfo),
		credits:     make(map[string]*ModelCreditsInfo),
		subscribers: make(map[string]map[int]*stateSubscriber),
		logger:      slog.Default().With("component", "mock-ext-server"),
	}

	const extService = "/exa.extension_server_pb.ExtensionServerService/"
	mux := http.NewServeMux()

	// Unary RPCs with dedicated handlers — respond with application/proto.
	unaryRPCs := map[string]unaryHandler{
		"LanguageServerStarted":            m.onLanguageServerStarted,
		"Heartbeat":                        m.onHeartbeat,
		"GetSecretValue":                   m.onGetSecretValue,
		"StoreSecretValue":                 m.onStoreSecretValue,
		"IsAgentManagerEnabled":            m.onIsAgentManagerEnabled,
		"PushUnifiedStateSyncUpdate":       m.onPushUnifiedStateSyncUpdate,
		"RecordError":                      m.onRecordError,
		"LogEvent":                         m.onLogEvent,
		"UpdateCascadeTrajectorySummaries": m.onUpdateTrajectorySummaries,
	}
	for name, h := range unaryRPCs {
		mux.HandleFunc(extService+name, m.handleUnary(h))
	}

	// Unary RPCs the LS may call that need no behavior here: each gets an
	// empty response from onDefault.
	for _, name := range []string{
		"BroadcastConversationDeletion",
		"WriteCascadeEdit",
		"OpenDiffZones",
		"HandleAsyncPostMessage",
		"OpenFilePointer",
		"OpenVirtualFile",
		"SaveDocument",
		"RestartUserStatusUpdater",
		"FocusIDEWindow",
		"SmartFocusConversation",
		"RunExtensionCode",
		"UpdateDetailedViewWithCascadeInput",
		"FindAllReferences",
		"GetDefinition",
		"GetLintErrors",
	} {
		mux.HandleFunc(extService+name, m.handleUnary(m.onDefault))
	}

	// Server-streaming RPCs — respond with application/connect+proto.
	mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
	mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))

	// Catch-all for any unregistered RPCs.
	mux.HandleFunc("/", m.handleCatchAll)

	m.server = &http.Server{
		Handler: mux,
		// Bound how long a client may take to send request headers
		// (slowloris; gosec G112). ReadTimeout and WriteTimeout are left
		// unset on purpose: they would risk cutting off long-lived streaming
		// responses such as SubscribeToUnifiedStateSyncTopic.
		ReadHeaderTimeout: 10 * time.Second,
	}
	go func() {
		// Serve returns http.ErrServerClosed after Close; anything else is
		// unexpected.
		if err := m.server.Serve(listener); err != http.ErrServerClosed {
			m.logger.Error("extension server error", "err", err)
		}
	}()
	m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
	return m, nil
}
// Port returns the TCP port the server listens on. The OS assigns it when
// NewMockExtensionServer opens its 127.0.0.1:0 listener.
func (m *MockExtensionServer) Port() int {
	return m.port
}
// SetToken sets the OAuth token for accountID and makes accountID the current
// account. A copy of *info is stored, so later changes to *info by the caller
// have no effect. When info is non-nil, the new token is pushed as an applied
// update to every uss-oauth subscriber bound to accountID or to no account.
//
// A nil info removes the account's stored token, so lookups fall back to
// another account's token. Nothing is published in that case.
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
	var stored *TokenInfo
	if info != nil {
		cp := *info
		stored = &cp
	}

	m.mu.Lock()
	if stored == nil {
		// Delete rather than store nil: currentTokenLocked's fallback ranges
		// over m.tokens and must not see nil entries.
		delete(m.tokens, accountID)
	} else {
		m.tokens[accountID] = stored
	}
	m.lastAccountID = accountID
	subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
	m.mu.Unlock()

	if stored == nil {
		return
	}
	tokenBin := buildOAuthTokenInfoBinary(stored.AccessToken, stored.RefreshToken, stored.ExpiresAt)
	tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
	m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
}
// SetModelCredits sets the uss-modelCredits state for accountID, makes
// accountID the current account, and pushes the new state as applied updates
// to every uss-modelCredits subscriber bound to accountID or to no account.
// A nil info is treated as the zero value.
//
// info is deep-copied, including the values behind its *int32 fields. The
// caller may therefore modify its struct or the pointed-to values afterwards
// without racing with readers of the stored state.
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
	var copyInfo ModelCreditsInfo
	if info != nil {
		copyInfo = *info
	}
	// The struct copy still shares the *int32 targets with the caller.
	if p := copyInfo.AvailableCredits; p != nil {
		v := *p
		copyInfo.AvailableCredits = &v
	}
	if p := copyInfo.MinimumCreditAmountForUsage; p != nil {
		v := *p
		copyInfo.MinimumCreditAmountForUsage = &v
	}

	m.mu.Lock()
	m.credits[accountID] = &copyInfo
	m.lastAccountID = accountID
	subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
	m.mu.Unlock()

	m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(&copyInfo)...)
}
// SetTrajectoryCallback registers fn to be invoked for each
// PushUnifiedStateSyncUpdate request the LS sends; nil disables the callback.
// The field is written under m.mu because HTTP handler goroutines read it
// concurrently.
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.onTrajectoryUpdate = fn
}
// currentTokenLocked returns the token stored for lastAccountID. If there is
// none, it returns an arbitrary non-nil stored token (map iteration order is
// random), or nil when no token is stored. Callers must hold m.mu (read or
// write).
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
	if m.lastAccountID != "" {
		if info := m.tokens[m.lastAccountID]; info != nil {
			return info
		}
	}
	for _, info := range m.tokens {
		// Skip nil entries so a cleared account cannot hide a usable token.
		if info != nil {
			return info
		}
	}
	return nil
}
// currentModelCreditsLocked returns the credits state stored for
// lastAccountID. If there is none, it returns an arbitrary non-nil stored
// state (map iteration order is random), or nil when nothing is stored.
// Callers must hold m.mu (read or write).
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
	if m.lastAccountID != "" {
		if info := m.credits[m.lastAccountID]; info != nil {
			return info
		}
	}
	for _, info := range m.credits {
		// Skip nil entries, consistent with currentTokenLocked.
		if info != nil {
			return info
		}
	}
	return nil
}
// tokenForAccountLocked returns the token stored for accountID. When
// accountID is empty or has no token, it falls back to currentTokenLocked.
// Callers must hold m.mu.
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
	info := m.tokens[accountID]
	if accountID == "" || info == nil {
		return m.currentTokenLocked()
	}
	return info
}
// creditsForAccountLocked returns the credits state stored for accountID.
// When accountID is empty or has no state, it falls back to
// currentModelCreditsLocked. Callers must hold m.mu.
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
	info := m.credits[accountID]
	if accountID == "" || info == nil {
		return m.currentModelCreditsLocked()
	}
	return info
}
// snapshotSubscribersLocked returns the non-nil subscribers of topic that
// should receive an update for accountID. A subscriber matches when either
// account ID is empty or the two are equal. The result is a fresh slice, so
// it stays valid after m.mu is released. It is nil when topic has no
// subscribers at all. Callers must hold m.mu.
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
	subs := m.subscribers[topic]
	if len(subs) == 0 {
		return nil
	}
	matches := func(s *stateSubscriber) bool {
		return accountID == "" || s.accountID == "" || s.accountID == accountID
	}
	snapshot := make([]*stateSubscriber, 0, len(subs))
	for _, s := range subs {
		if s != nil && matches(s) {
			snapshot = append(snapshot, s)
		}
	}
	return snapshot
}
// publishTopicUpdate queues every non-empty update, in order, on each
// non-nil subscriber's channel without blocking. Each send gets its own copy
// of the bytes. If a subscriber's buffer is full, that update is dropped and
// a warning is logged. It does not touch m.mu; callers in this file invoke it
// after releasing the lock.
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
	for _, s := range subscribers {
		if s == nil {
			continue
		}
		for _, u := range updates {
			if len(u) == 0 {
				continue
			}
			select {
			case s.updates <- append([]byte(nil), u...):
			default:
				m.logger.Warn("dropping USS update", "topic", s.topic, "account", s.accountID)
			}
		}
	}
}
// buildModelCreditsAppliedUpdates converts info (nil = zero value) into
// three applied updates for the uss-modelCredits topic, in this order:
//  1. useAICreditsSentinelKey with a bool Primitive row;
//  2. availableCreditsSentinelKey with an int32 Primitive row (AvailableCredits,
//     or 9999) when UseAICredits is true, otherwise the key with no row;
//  3. minimumCreditAmountForUsageKey with an int32 Primitive row
//     (MinimumCreditAmountForUsage, or defaultMinimumCreditAmountForUsage).
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
	var cfg ModelCreditsInfo
	if info != nil {
		cfg = *info
	}
	primitiveRow := func(p []byte) []byte { return buildUSSRowBinary(string(p)) }

	useCredits := buildAppliedUpdate(useAICreditsSentinelKey,
		primitiveRow(buildPrimitiveBoolBinary(cfg.UseAICredits)))

	// With credits disabled, the key is sent without a row.
	var availableRow []byte
	if cfg.UseAICredits {
		available := int32(9999)
		if cfg.AvailableCredits != nil {
			available = *cfg.AvailableCredits
		}
		availableRow = primitiveRow(buildPrimitiveInt32Binary(available))
	}
	availableCredits := buildAppliedUpdate(availableCreditsSentinelKey, availableRow)

	minimum := defaultMinimumCreditAmountForUsage
	if cfg.MinimumCreditAmountForUsage != nil {
		minimum = *cfg.MinimumCreditAmountForUsage
	}
	minimumCredits := buildAppliedUpdate(minimumCreditAmountForUsageKey,
		primitiveRow(buildPrimitiveInt32Binary(minimum)))

	return [][]byte{useCredits, availableCredits, minimumCredits}
}
// Close shuts down the server via http.Server.Close, which closes the
// listener and all active connections immediately, without a graceful drain.
// It returns the error reported by http.Server.Close.
func (m *MockExtensionServer) Close() error {
	return m.server.Close()
}
// ============================================================
// Middleware
// ============================================================
// unaryHandler receives the protobuf request message as prepared by
// handleUnary and returns the raw protobuf response message. A nil or empty
// result yields an empty response body.
type unaryHandler func(body []byte) []byte

// streamHandler receives the request message as prepared by handleStream and
// writes envelope-framed response messages to w itself. handleStream has
// already written the 200 status and Content-Type before calling it.
type streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
// handleUnary adapts handler into an http.HandlerFunc for a unary RPC. It
// enforces the CSRF header and reads the body. When the request Content-Type
// contains "connect+proto" (the LS may use envelope framing even for unary
// calls), it strips the ConnectRPC envelope. It then passes the proto bytes
// to handler and writes the result, possibly empty, as a 200 response with
// Content-Type "application/proto".
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.checkCSRF(w, r) {
			return
		}
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// NOTE(review): a read failure is answered with an empty, successful
			// response rather than an error; presumably so the LS treats the
			// call as a no-op. Confirm this is intended.
			m.logger.Error("read body", "err", err, "path", r.URL.Path)
			w.Header().Set("Content-Type", "application/proto")
			w.WriteHeader(http.StatusOK)
			return
		}

		ct := r.Header.Get("Content-Type")
		protoBody := body
		if strings.Contains(ct, "connect+proto") {
			// unwrapConnectEnvelope returns body unchanged if it is too short
			// or not framed.
			protoBody = unwrapConnectEnvelope(body)
		}
		m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
		responseProto := handler(protoBody)

		// Unary responses are always raw "application/proto", regardless of
		// the framing the request used.
		w.Header().Set("Content-Type", "application/proto")
		w.WriteHeader(http.StatusOK)
		if len(responseProto) > 0 {
			if _, err := w.Write(responseProto); err != nil {
				m.logger.Debug("unary RPC write failed", "path", r.URL.Path, "err", err)
			}
		}
	}
}
// handleStream adapts handler into an http.HandlerFunc for a server-streaming
// RPC. After the CSRF check it reads the whole request body. It strips the
// ConnectRPC envelope for "connect+proto" and "connect+json" content types
// (JSON payloads are then still handed to handler as-is). It then commits a
// 200 response with Content-Type "application/connect+proto" and lets
// handler write the framed messages.
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.checkCSRF(w, r) {
			return
		}
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			m.logger.Error("read body", "err", err, "path", r.URL.Path)
			return
		}

		payload := raw
		ct := r.Header.Get("Content-Type")
		if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
			payload = unwrapConnectEnvelope(raw)
		}
		m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(payload))

		hdr := w.Header()
		hdr.Set("Content-Type", "application/connect+proto")
		w.WriteHeader(200)
		handler(payload, w, r)
	}
}
// checkCSRF reports whether r carries the expected x-codeium-csrf-token
// header; an empty m.csrf disables the check. On a mismatch it writes a 403
// "Invalid CSRF token" plain-text response and returns false, and the caller
// must stop handling the request.
//
// The comparison is constant-time, so response timing does not reveal how
// much of a guessed token matched.
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
	if m.csrf == "" {
		return true
	}
	token := r.Header.Get("x-codeium-csrf-token")
	if subtle.ConstantTimeCompare([]byte(token), []byte(m.csrf)) == 1 {
		return true
	}
	m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusForbidden)
	if _, err := w.Write([]byte("Invalid CSRF token")); err != nil {
		m.logger.Debug("CSRF rejection write failed", "path", r.URL.Path, "err", err)
	}
	return false
}
// min returns the smaller of a and b. This package-level helper shadows Go's
// builtin min (Go 1.21+) throughout package lspool.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// ============================================================
// Unary RPC Handlers — each receives raw proto request body,
// returns raw proto response body.
// ============================================================
// onLanguageServerStarted handles LanguageServerStarted. The request
// (LanguageServerStartedRequest: https_port(1), http_port(2), lsp_port(3),
// csrf_token(4)) is informational only. It is not decoded; only its size is
// logged. The reply is an empty LanguageServerStartedResponse.
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
	size := len(body)
	m.logger.Info("LanguageServerStarted", "body_len", size)
	return nil
}
// onHeartbeat acknowledges a Heartbeat with an empty HeartbeatResponse; the
// request is ignored.
func (m *MockExtensionServer) onHeartbeat(_ []byte) []byte {
	return nil
}
// onGetSecretValue answers GetSecretValue (requested key in field 1). Whatever
// the key, it replies with the current account's access token (see
// currentTokenLocked) in field 1 of GetSecretValueResponse. It replies with
// an empty response when no token is stored or the token is empty.
// NOTE(review): the key is only logged, never used to select a secret.
// Confirm the LS never asks for any other secret.
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
	m.logger.Debug("GetSecretValue", "key", decodeProtoString(body, 1))

	m.mu.RLock()
	token := ""
	if current := m.currentTokenLocked(); current != nil {
		token = current.AccessToken
	}
	m.mu.RUnlock()

	if token == "" {
		return nil
	}
	return encodeProtoString(1, token)
}
// onStoreSecretValue logs the key (field 1) of a StoreSecretValue request.
// Nothing is stored; the reply is an empty response.
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
	m.logger.Debug("StoreSecretValue", "key", decodeProtoString(body, 1))
	return nil
}
// onIsAgentManagerEnabled always reports the agent manager as disabled. It
// sets IsAgentManagerEnabledResponse.enabled (field 1) to false, written
// explicitly.
func (m *MockExtensionServer) onIsAgentManagerEnabled(_ []byte) []byte {
	const enabledField = 1
	return encodeProtoBool(enabledField, false)
}
// onPushUnifiedStateSyncUpdate handles PushUnifiedStateSyncUpdate. The body
// is not decoded. If a trajectory callback is registered, it is invoked with
// empty topic and key and the raw PushUnifiedStateSyncUpdateRequest bytes.
// The reply is an empty PushUnifiedStateSyncUpdateResponse.
//
// Request layout, for a future decoder: field 1 holds the nested
// UpdateRequest (topic_name = 1, key = 2, applied_update = 5).
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
	m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))

	// Read the callback under the lock, because SetTrajectoryCallback may
	// replace it from another goroutine. Invoke it after unlocking so the
	// callback may call back into the server.
	m.mu.RLock()
	cb := m.onTrajectoryUpdate
	m.mu.RUnlock()
	if cb != nil {
		cb("", "", body)
	}
	return nil
}
// onRecordError accepts a RecordError report from the LS. It logs only the
// report's size at debug level and returns an empty response.
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
	size := len(body)
	m.logger.Debug("RecordError", "body_len", size)
	return nil
}
// onLogEvent silently discards LS telemetry events and returns an empty
// response.
func (m *MockExtensionServer) onLogEvent(_ []byte) []byte {
	return nil
}
// onUpdateTrajectorySummaries accepts UpdateCascadeTrajectorySummaries. It
// logs only the request size at debug level and returns an empty response.
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
	size := len(body)
	m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", size)
	return nil
}
// onDefault is the shared no-op handler for unary RPCs that the mock accepts
// but does not implement. It ignores the request and returns an empty
// response.
func (m *MockExtensionServer) onDefault(_ []byte) []byte {
	return nil
}
// ============================================================
// Streaming RPC Handlers
// ============================================================
// onSubscribeStateSyncTopic serves SubscribeToUnifiedStateSyncTopic, with
// the topic name in request field 1. It registers a stateSubscriber bound to
// the current account and writes one envelope-framed UnifiedStateSyncUpdate
// carrying the topic's initial_state. It then forwards every update queued
// on the subscriber's channel until the client disconnects or a write fails.
// The subscriber is unregistered when the handler returns.
//
// Registration and the initial snapshot happen under the same lock, so a
// concurrent publish is either reflected in the snapshot or queued on the
// channel and delivered after it.
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
	topic := decodeProtoString(body, 1)
	m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
	flusher, ok := w.(http.Flusher)
	if !ok {
		m.logger.Error("ResponseWriter does not support Flush")
		return
	}

	m.mu.Lock()
	sub := &stateSubscriber{
		id:        m.nextSubID,
		accountID: m.lastAccountID,
		topic:     topic,
		updates:   make(chan []byte, 16),
	}
	m.nextSubID++
	if m.subscribers[topic] == nil {
		m.subscribers[topic] = make(map[int]*stateSubscriber)
	}
	m.subscribers[topic][sub.id] = sub
	topicData := m.stateSyncInitialTopicLocked(topic, sub.accountID)
	m.mu.Unlock()
	defer m.removeStateSubscriber(sub)

	// Send initial state as envelope-framed message.
	if _, err := w.Write(connectEnvelope(0x00, buildInitialStateUpdate(topicData))); err != nil {
		m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
		return
	}
	flusher.Flush()

	for {
		select {
		case <-r.Context().Done():
			m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
			return
		case update := <-sub.updates:
			if len(update) == 0 {
				continue
			}
			if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
				m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
				return
			}
			flusher.Flush()
		}
	}
}

// stateSyncInitialTopicLocked builds the serialized Topic sent as the
// initial_state for topic, using accountID's stored state. Callers must hold
// m.mu.
func (m *MockExtensionServer) stateSyncInitialTopicLocked(topic, accountID string) []byte {
	switch topic {
	case "uss-oauth":
		if info := m.tokenForAccountLocked(accountID); info != nil {
			return buildUSSTopicWithOAuth(info.AccessToken, info.RefreshToken, info.ExpiresAt)
		}
	case "uss-modelCredits":
		if info := m.creditsForAccountLocked(accountID); info != nil {
			return buildUSSTopicWithModelCredits(info)
		}
	}
	// All other topics (browserPreferences, enterprisePreferences, etc.), and
	// known topics with no stored state, get an empty Topic.
	return buildEmptyTopic()
}

// removeStateSubscriber unregisters sub, dropping its topic's map once empty.
func (m *MockExtensionServer) removeStateSubscriber(sub *stateSubscriber) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if topicSubs := m.subscribers[sub.topic]; topicSubs != nil {
		delete(topicSubs, sub.id)
		if len(topicSubs) == 0 {
			delete(m.subscribers, sub.topic)
		}
	}
}
// onExecuteCommand answers ExecuteCommand without executing anything. It
// immediately writes the ConnectRPC end-of-stream frame, so the LS sees a
// completed, empty stream.
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
	m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
	if _, err := w.Write(connectEndOfStream()); err != nil {
		m.logger.Debug("ExecuteCommand end-of-stream write failed", "err", err)
		return
	}
	// Flush when supported; otherwise net/http sends the buffered frame when
	// the handler returns.
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
}
// ============================================================
// Catch-all handler
// ============================================================
// handleCatchAll answers any RPC without a dedicated route. After the CSRF
// check it drains the request body and replies 200 with Content-Type
// "application/proto" and an empty body, i.e. an empty unary response.
//
// The same reply is used whatever the request Content-Type, including
// "connect+" requests that might be streaming calls. Some clients use that
// prefix for unary calls too, so an empty unary reply is the chosen default.
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
	if !m.checkCSRF(w, r) {
		return
	}
	ct := r.Header.Get("Content-Type")
	m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method, "content_type", ct)

	// Drain the body without buffering it; its contents are never used.
	if _, err := io.Copy(io.Discard, r.Body); err != nil {
		m.logger.Debug("drain request body", "path", r.URL.Path, "err", err)
	}

	w.Header().Set("Content-Type", "application/proto")
	w.WriteHeader(http.StatusOK)
}