921 lines
28 KiB
Go
921 lines
28 KiB
Go
// Package lspool provides a mock Extension Server that the LS binary connects
|
|
// to at startup. The real IDE's extension.js runs a ConnectRPC HTTP/1.1 server
|
|
// using connectNodeAdapter. We replicate that protocol here.
|
|
//
|
|
// Protocol details (from extension.js source):
|
|
// - Transport: HTTP/1.1 on 127.0.0.1 (no TLS)
|
|
// - Auth: x-codeium-csrf-token header on every request
|
|
// - Unary request Content-Type: application/proto (binary protobuf, no envelope)
|
|
// OR application/connect+proto (with 5-byte envelope)
|
|
// - Unary response Content-Type: application/proto (raw binary protobuf, no envelope)
|
|
// - Stream request Content-Type: application/connect+proto (with 5-byte envelope)
|
|
// - Stream response Content-Type: application/connect+proto (envelope-framed messages)
|
|
//
|
|
// The LS may send requests with content-type "application/connect+proto" for
// BOTH unary and streaming RPCs. ConnectRPC matches the content type with:
//
//	/^application\/(connect\+)?(?:(json)(?:; ?charset=utf-?8)?|(proto))$/i
//
// If the "connect+" prefix is present → stream mode; otherwise → unary mode.
// However, the LS Go client uses the Connect protocol client, which is expected
// to send "application/proto" for unary and "application/connect+proto" for
// streaming.
//
// Because the content type alone is therefore not a reliable signal, we detect
// the RPC kind from the URL path and respond accordingly.
|
|
package lspool
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
)
|
|
|
|
// ============================================================
|
|
// Proto helpers — hand-encode minimal proto messages so we don't
|
|
// need to import the full generated proto package.
|
|
// ============================================================
|
|
|
|
// encodeProtoString writes a proto string field (wire type 2) to a byte slice.
|
|
func encodeProtoString(fieldNum int, val string) []byte {
|
|
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
|
length := encodeVarint(uint64(len(val)))
|
|
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
|
out = append(out, tag...)
|
|
out = append(out, length...)
|
|
out = append(out, []byte(val)...)
|
|
return out
|
|
}
|
|
|
|
// encodeProtoBytes writes a proto bytes/message field (wire type 2).
|
|
func encodeProtoBytes(fieldNum int, val []byte) []byte {
|
|
tag := encodeVarint(uint64(fieldNum<<3 | 2))
|
|
length := encodeVarint(uint64(len(val)))
|
|
out := make([]byte, 0, len(tag)+len(length)+len(val))
|
|
out = append(out, tag...)
|
|
out = append(out, length...)
|
|
out = append(out, val...)
|
|
return out
|
|
}
|
|
|
|
// encodeProtoVarint writes a proto varint field (wire type 0).
|
|
func encodeProtoVarint(fieldNum int, val uint64) []byte {
|
|
tag := encodeVarint(uint64(fieldNum<<3 | 0))
|
|
v := encodeVarint(val)
|
|
out := make([]byte, 0, len(tag)+len(v))
|
|
out = append(out, tag...)
|
|
out = append(out, v...)
|
|
return out
|
|
}
|
|
|
|
// encodeProtoBool writes a proto bool field.
|
|
func encodeProtoBool(fieldNum int, val bool) []byte {
|
|
v := uint64(0)
|
|
if val {
|
|
v = 1
|
|
}
|
|
return encodeProtoVarint(fieldNum, v)
|
|
}
|
|
|
|
// encodeVarint returns v in the unsigned base-128 varint encoding produced by
// binary.PutUvarint, using at most binary.MaxVarintLen64 bytes.
func encodeVarint(v uint64) []byte {
	var scratch [binary.MaxVarintLen64]byte
	n := binary.PutUvarint(scratch[:], v)
	return append([]byte(nil), scratch[:n]...)
}
|
|
|
|
// decodeProtoString scans raw proto-encoded bytes and returns the payload of
// the first length-delimited (wire type 2) field whose number equals
// targetField, converted to a string.
//
// Varint, 64-bit and 32-bit fields are skipped. The function returns "" when
// the field is absent, when the input is truncated or malformed, or when an
// unsupported wire type is encountered. It never panics on hostile input: a
// declared length is validated against the bytes actually remaining before it
// is converted to int.
func decodeProtoString(data []byte, targetField int) string {
	i := 0
	for i < len(data) {
		tag, n := binary.Uvarint(data[i:])
		if n <= 0 {
			return ""
		}
		i += n
		fieldNum := int(tag >> 3)

		switch tag & 0x7 {
		case 0: // varint
			_, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
		case 1: // 64-bit
			i += 8
		case 2: // length-delimited
			length, n := binary.Uvarint(data[i:])
			if n <= 0 {
				return ""
			}
			i += n
			// Compare as uint64: converting an oversized length to int first
			// could yield a negative value that slips past the bounds check.
			if length > uint64(len(data)-i) {
				return ""
			}
			end := i + int(length)
			if fieldNum == targetField {
				return string(data[i:end])
			}
			i = end
		case 5: // 32-bit
			i += 4
		default:
			return ""
		}
	}
	return ""
}
|
|
|
|
// ============================================================
|
|
// ConnectRPC envelope helpers
|
|
// ============================================================
|
|
|
|
// connectEnvelope frames payload as a ConnectRPC streaming envelope: one flags
// byte, a 4-byte big-endian payload length, then the payload itself.
func connectEnvelope(flags byte, payload []byte) []byte {
	header := [5]byte{flags}
	binary.BigEndian.PutUint32(header[1:], uint32(len(payload)))
	return append(header[:], payload...)
}
|
|
|
|
// connectEndOfStream returns the end-of-stream trailer frame for ConnectRPC.
|
|
// flags=0x02 signals end of stream. The payload is a JSON object with empty metadata.
|
|
func connectEndOfStream() []byte {
|
|
trailer := []byte("{}")
|
|
return connectEnvelope(0x02, trailer)
|
|
}
|
|
|
|
// unwrapConnectEnvelope strips the 5-byte envelope header (1 flags byte plus a
// 4-byte big-endian length) from a ConnectRPC message and returns the payload.
//
// The input is returned unchanged when it does not look like an envelope:
// when it is shorter than 5 bytes, when the flags byte is greater than 0x02,
// or when the declared payload length exceeds the bytes that follow the
// header. Bytes after the declared payload are dropped.
func unwrapConnectEnvelope(body []byte) []byte {
	const headerLen = 5
	if len(body) < headerLen {
		return body
	}
	// Only flags 0x00, 0x01 and 0x02 are treated as envelope markers.
	if body[0] > 0x02 {
		return body // Not envelope-framed, return raw
	}
	plen := binary.BigEndian.Uint32(body[1:headerLen])
	// Compare in uint64 so a large declared length cannot wrap around on
	// platforms where int is 32 bits wide.
	if uint64(plen) > uint64(len(body)-headerLen) {
		return body // Length mismatch, return raw
	}
	return body[headerLen : headerLen+int(plen)]
}
|
|
|
|
// ============================================================
|
|
// OAuthTokenInfo proto builder
|
|
// ============================================================
|
|
|
|
// buildOAuthTokenInfoBinary creates binary-encoded OAuthTokenInfo proto.
|
|
//
|
|
// message OAuthTokenInfo {
|
|
// string access_token = 1;
|
|
// string token_type = 2;
|
|
// string refresh_token = 3;
|
|
// google.protobuf.Timestamp expiry = 4;
|
|
// bool is_gcp_tos = 6;
|
|
// }
|
|
func buildOAuthTokenInfoBinary(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
|
var buf []byte
|
|
buf = append(buf, encodeProtoString(1, accessToken)...)
|
|
buf = append(buf, encodeProtoString(2, "Bearer")...)
|
|
if refreshToken != "" {
|
|
buf = append(buf, encodeProtoString(3, refreshToken)...)
|
|
}
|
|
// Use real expiry if provided, otherwise default to 1 hour from now
|
|
expiry := expiresAt
|
|
if expiry.IsZero() {
|
|
expiry = time.Now().Add(1 * time.Hour)
|
|
}
|
|
ts := ×tamppb.Timestamp{
|
|
Seconds: expiry.Unix(),
|
|
}
|
|
tsBytes, _ := proto.Marshal(ts)
|
|
buf = append(buf, encodeProtoBytes(4, tsBytes)...)
|
|
buf = append(buf, encodeProtoBool(6, true)...)
|
|
return buf
|
|
}
|
|
|
|
// buildUSSTopicWithOAuth creates a USS Topic proto with the OAuth token.
|
|
//
|
|
// message Topic { map<string, Row> data = 1; }
|
|
// message Row { string value = 1; int64 e_tag = 2; }
|
|
//
|
|
// The key in the map is "oauthTokenInfoSentinelKey" and the Row.value is
|
|
// base64(toBinary(OAuthTokenInfo)).
|
|
func buildUSSTopicWithOAuth(accessToken, refreshToken string, expiresAt time.Time) []byte {
|
|
tokenBin := buildOAuthTokenInfoBinary(accessToken, refreshToken, expiresAt)
|
|
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
|
|
|
// Row: value=tokenB64 (field 1), e_tag=1 (field 2)
|
|
var row []byte
|
|
row = append(row, encodeProtoString(1, tokenB64)...)
|
|
row = append(row, encodeProtoVarint(2, 1)...)
|
|
|
|
// Map entry: key="oauthTokenInfoSentinelKey" (field 1), value=row (field 2)
|
|
var entry []byte
|
|
entry = append(entry, encodeProtoString(1, "oauthTokenInfoSentinelKey")...)
|
|
entry = append(entry, encodeProtoBytes(2, row)...)
|
|
|
|
// Topic: data map entries use field 1
|
|
var topic []byte
|
|
topic = append(topic, encodeProtoBytes(1, entry)...)
|
|
|
|
return topic
|
|
}
|
|
|
|
func buildPrimitiveBoolBinary(val bool) []byte {
|
|
// Primitive.bool_value is field 13 in the proto definition
|
|
return encodeProtoBool(13, val)
|
|
}
|
|
|
|
func buildPrimitiveInt32Binary(val int32) []byte {
|
|
// Primitive.int32_value is field 3 in the proto definition
|
|
return encodeProtoVarint(3, uint64(uint32(val)))
|
|
}
|
|
|
|
// encodeUSSBinaryValue returns value as standard (padded) base64 text, the
// form in which binary payloads are stored in a USS Row value.
func encodeUSSBinaryValue(value []byte) string {
	text := make([]byte, base64.StdEncoding.EncodedLen(len(value)))
	base64.StdEncoding.Encode(text, value)
	return string(text)
}
|
|
|
|
func encodeUSSPrimitiveBoolValue(val bool) string {
|
|
return encodeUSSBinaryValue(buildPrimitiveBoolBinary(val))
|
|
}
|
|
|
|
func encodeUSSPrimitiveInt32Value(val int32) string {
|
|
return encodeUSSBinaryValue(buildPrimitiveInt32Binary(val))
|
|
}
|
|
|
|
func buildUSSTopicRow(key string, value string) []byte {
|
|
row := buildUSSRowBinary(value)
|
|
|
|
var entry []byte
|
|
entry = append(entry, encodeProtoString(1, key)...)
|
|
entry = append(entry, encodeProtoBytes(2, row)...)
|
|
return entry
|
|
}
|
|
|
|
func buildUSSRowBinary(value string) []byte {
|
|
var row []byte
|
|
row = append(row, encodeProtoString(1, value)...)
|
|
row = append(row, encodeProtoVarint(2, 1)...)
|
|
return row
|
|
}
|
|
|
|
func buildUSSTopicWithModelCredits(info *ModelCreditsInfo) []byte {
|
|
if info == nil {
|
|
info = &ModelCreditsInfo{}
|
|
}
|
|
|
|
minimum := defaultMinimumCreditAmountForUsage
|
|
if info.MinimumCreditAmountForUsage != nil {
|
|
minimum = *info.MinimumCreditAmountForUsage
|
|
}
|
|
|
|
entries := make([][]byte, 0, 3)
|
|
entries = append(entries, buildUSSTopicRow(
|
|
useAICreditsSentinelKey,
|
|
encodeUSSPrimitiveBoolValue(info.UseAICredits),
|
|
))
|
|
// JS protocol: useAICreditsSentinelKey carries the toggle state.
|
|
// availableCreditsSentinelKey is only present when credits are enabled.
|
|
if info.UseAICredits {
|
|
credits := int32(9999)
|
|
if info.AvailableCredits != nil {
|
|
credits = *info.AvailableCredits
|
|
}
|
|
entries = append(entries, buildUSSTopicRow(availableCreditsSentinelKey, encodeUSSPrimitiveInt32Value(credits)))
|
|
}
|
|
entries = append(entries, buildUSSTopicRow(minimumCreditAmountForUsageKey, encodeUSSPrimitiveInt32Value(minimum)))
|
|
|
|
var topic []byte
|
|
for _, entry := range entries {
|
|
topic = append(topic, encodeProtoBytes(1, entry)...)
|
|
}
|
|
return topic
|
|
}
|
|
|
|
// buildEmptyTopic returns the encoding of a USS Topic with no map entries,
// which is a zero-length (but non-nil) byte slice.
func buildEmptyTopic() []byte {
	return make([]byte, 0)
}
|
|
|
|
// ============================================================
|
|
// UnifiedStateSyncUpdate builder
|
|
// ============================================================
|
|
|
|
// buildInitialStateUpdate creates a UnifiedStateSyncUpdate with initial_state set.
|
|
//
|
|
// message UnifiedStateSyncUpdate {
|
|
// oneof update_type {
|
|
// Topic initial_state = 1;
|
|
// AppliedUpdate applied_update = 2;
|
|
// }
|
|
// }
|
|
func buildInitialStateUpdate(topicData []byte) []byte {
|
|
return encodeProtoBytes(1, topicData)
|
|
}
|
|
|
|
func buildAppliedUpdate(key string, row []byte) []byte {
|
|
var applied []byte
|
|
applied = append(applied, encodeProtoString(1, key)...)
|
|
if len(row) > 0 {
|
|
applied = append(applied, encodeProtoBytes(2, row)...)
|
|
}
|
|
return encodeProtoBytes(2, applied)
|
|
}
|
|
|
|
// ============================================================
|
|
// MockExtensionServer
|
|
// ============================================================
|
|
|
|
// MockExtensionServer provides a ConnectRPC-compatible HTTP server that the
|
|
// Language Server binary connects to. It implements just enough of the
|
|
// ExtensionServerService to keep the LS operational.
|
|
type MockExtensionServer struct {
|
|
listener net.Listener
|
|
server *http.Server
|
|
port int
|
|
csrf string
|
|
mu sync.RWMutex
|
|
tokens map[string]*TokenInfo // account_id -> token info
|
|
credits map[string]*ModelCreditsInfo // account_id -> model credits info
|
|
subscribers map[string]map[int]*stateSubscriber
|
|
nextSubID int
|
|
lastAccountID string
|
|
logger *slog.Logger
|
|
|
|
// Trajectory callback — when LS pushes trajectory updates, we forward them
|
|
onTrajectoryUpdate func(topic, key string, data []byte)
|
|
}
|
|
|
|
// TokenInfo carries the OAuth credentials registered for one account.
type TokenInfo struct {
	// AccessToken is the OAuth access token.
	AccessToken string

	// RefreshToken is the OAuth refresh token; it may be empty.
	RefreshToken string

	// ExpiresAt is the token expiry. The zero value means unknown; it
	// defaults to now+1h when the token is encoded.
	ExpiresAt time.Time
}
|
|
|
|
// ModelCreditsInfo mirrors the JS uss-modelCredits topic state.
type ModelCreditsInfo struct {
	// UseAICredits is the toggle state published under useAICreditsSentinelKey.
	UseAICredits bool

	// AvailableCredits is the available credit count; nil means "use the
	// built-in default".
	AvailableCredits *int32

	// MinimumCreditAmountForUsage is the minimum credit amount; nil means
	// "use defaultMinimumCreditAmountForUsage".
	MinimumCreditAmountForUsage *int32
}
|
|
|
|
// stateSubscriber represents one open state-sync subscription stream.
type stateSubscriber struct {
	// id is the subscriber's key in the per-topic subscriber map.
	id int

	// accountID is the account the subscription was associated with when it
	// was created; it may be empty.
	accountID string

	// topic is the subscribed topic name.
	topic string

	// updates carries encoded update messages to the stream's writer.
	updates chan []byte
}
|
|
|
|
// Keys and defaults for the uss-modelCredits topic.
const (
	// useAICreditsSentinelKey is the row key carrying the credits toggle.
	useAICreditsSentinelKey = "useAICreditsSentinelKey"

	// availableCreditsSentinelKey is the row key carrying the available
	// credit count.
	availableCreditsSentinelKey = "availableCreditsSentinelKey"

	// minimumCreditAmountForUsageKey is the row key carrying the minimum
	// credit amount.
	minimumCreditAmountForUsageKey = "minimumCreditAmountForUsageKey"

	// defaultMinimumCreditAmountForUsage is used when no minimum is supplied.
	defaultMinimumCreditAmountForUsage = int32(50)
)
|
|
|
|
// NewMockExtensionServer creates a mock extension server with proper ConnectRPC handling.
|
|
func NewMockExtensionServer(csrf string) (*MockExtensionServer, error) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("listen: %w", err)
|
|
}
|
|
|
|
m := &MockExtensionServer{
|
|
listener: listener,
|
|
port: listener.Addr().(*net.TCPAddr).Port,
|
|
csrf: csrf,
|
|
tokens: make(map[string]*TokenInfo),
|
|
credits: make(map[string]*ModelCreditsInfo),
|
|
subscribers: make(map[string]map[int]*stateSubscriber),
|
|
logger: slog.Default().With("component", "mock-ext-server"),
|
|
}
|
|
|
|
mux := http.NewServeMux()
|
|
extService := "/exa.extension_server_pb.ExtensionServerService/"
|
|
|
|
// Register all RPCs the LS calls on the Extension Server.
|
|
// Unary RPCs — return application/proto
|
|
mux.HandleFunc(extService+"LanguageServerStarted", m.handleUnary(m.onLanguageServerStarted))
|
|
mux.HandleFunc(extService+"Heartbeat", m.handleUnary(m.onHeartbeat))
|
|
mux.HandleFunc(extService+"GetSecretValue", m.handleUnary(m.onGetSecretValue))
|
|
mux.HandleFunc(extService+"StoreSecretValue", m.handleUnary(m.onStoreSecretValue))
|
|
mux.HandleFunc(extService+"IsAgentManagerEnabled", m.handleUnary(m.onIsAgentManagerEnabled))
|
|
mux.HandleFunc(extService+"PushUnifiedStateSyncUpdate", m.handleUnary(m.onPushUnifiedStateSyncUpdate))
|
|
mux.HandleFunc(extService+"RecordError", m.handleUnary(m.onRecordError))
|
|
mux.HandleFunc(extService+"LogEvent", m.handleUnary(m.onLogEvent))
|
|
mux.HandleFunc(extService+"UpdateCascadeTrajectorySummaries", m.handleUnary(m.onUpdateTrajectorySummaries))
|
|
mux.HandleFunc(extService+"BroadcastConversationDeletion", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"WriteCascadeEdit", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"OpenDiffZones", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"HandleAsyncPostMessage", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"OpenFilePointer", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"OpenVirtualFile", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"SaveDocument", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"RestartUserStatusUpdater", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"FocusIDEWindow", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"SmartFocusConversation", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"RunExtensionCode", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"UpdateDetailedViewWithCascadeInput", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"FindAllReferences", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"GetDefinition", m.handleUnary(m.onDefault))
|
|
mux.HandleFunc(extService+"GetLintErrors", m.handleUnary(m.onDefault))
|
|
|
|
// Server-streaming RPCs — return application/connect+proto
|
|
mux.HandleFunc(extService+"SubscribeToUnifiedStateSyncTopic", m.handleStream(m.onSubscribeStateSyncTopic))
|
|
mux.HandleFunc(extService+"ExecuteCommand", m.handleStream(m.onExecuteCommand))
|
|
|
|
// Catch-all for any unregistered RPCs
|
|
mux.HandleFunc("/", m.handleCatchAll)
|
|
|
|
m.server = &http.Server{Handler: mux}
|
|
|
|
go func() {
|
|
if err := m.server.Serve(listener); err != http.ErrServerClosed {
|
|
m.logger.Error("extension server error", "err", err)
|
|
}
|
|
}()
|
|
|
|
m.logger.Info("mock extension server started", "port", m.port, "csrf_len", len(csrf))
|
|
return m, nil
|
|
}
|
|
|
|
// Port returns the listening port.
|
|
func (m *MockExtensionServer) Port() int {
|
|
return m.port
|
|
}
|
|
|
|
// SetToken sets the OAuth token for an account.
|
|
func (m *MockExtensionServer) SetToken(accountID string, info *TokenInfo) {
|
|
m.mu.Lock()
|
|
m.tokens[accountID] = info
|
|
m.lastAccountID = accountID
|
|
subscribers := m.snapshotSubscribersLocked("uss-oauth", accountID)
|
|
m.mu.Unlock()
|
|
|
|
if info == nil {
|
|
return
|
|
}
|
|
tokenBin := buildOAuthTokenInfoBinary(info.AccessToken, info.RefreshToken, info.ExpiresAt)
|
|
tokenB64 := base64.StdEncoding.EncodeToString(tokenBin)
|
|
m.publishTopicUpdate(subscribers, buildAppliedUpdate("oauthTokenInfoSentinelKey", buildUSSRowBinary(tokenB64)))
|
|
}
|
|
|
|
// SetModelCredits sets the uss-modelCredits state for an account.
|
|
func (m *MockExtensionServer) SetModelCredits(accountID string, info *ModelCreditsInfo) {
|
|
if info == nil {
|
|
info = &ModelCreditsInfo{}
|
|
}
|
|
copyInfo := *info
|
|
m.mu.Lock()
|
|
m.credits[accountID] = ©Info
|
|
m.lastAccountID = accountID
|
|
subscribers := m.snapshotSubscribersLocked("uss-modelCredits", accountID)
|
|
m.mu.Unlock()
|
|
|
|
m.publishTopicUpdate(subscribers, buildModelCreditsAppliedUpdates(©Info)...)
|
|
}
|
|
|
|
// SetTrajectoryCallback registers a callback for when the LS pushes trajectory data.
|
|
func (m *MockExtensionServer) SetTrajectoryCallback(fn func(topic, key string, data []byte)) {
|
|
m.onTrajectoryUpdate = fn
|
|
}
|
|
|
|
func (m *MockExtensionServer) currentTokenLocked() *TokenInfo {
|
|
if m.lastAccountID != "" {
|
|
if info := m.tokens[m.lastAccountID]; info != nil {
|
|
return info
|
|
}
|
|
}
|
|
for _, info := range m.tokens {
|
|
return info
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) currentModelCreditsLocked() *ModelCreditsInfo {
|
|
if m.lastAccountID != "" {
|
|
if info := m.credits[m.lastAccountID]; info != nil {
|
|
return info
|
|
}
|
|
}
|
|
for _, info := range m.credits {
|
|
return info
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) tokenForAccountLocked(accountID string) *TokenInfo {
|
|
if accountID != "" {
|
|
if info := m.tokens[accountID]; info != nil {
|
|
return info
|
|
}
|
|
}
|
|
return m.currentTokenLocked()
|
|
}
|
|
|
|
func (m *MockExtensionServer) creditsForAccountLocked(accountID string) *ModelCreditsInfo {
|
|
if accountID != "" {
|
|
if info := m.credits[accountID]; info != nil {
|
|
return info
|
|
}
|
|
}
|
|
return m.currentModelCreditsLocked()
|
|
}
|
|
|
|
func (m *MockExtensionServer) snapshotSubscribersLocked(topic, accountID string) []*stateSubscriber {
|
|
topicSubs := m.subscribers[topic]
|
|
if len(topicSubs) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]*stateSubscriber, 0, len(topicSubs))
|
|
for _, sub := range topicSubs {
|
|
if sub == nil {
|
|
continue
|
|
}
|
|
if accountID != "" && sub.accountID != "" && sub.accountID != accountID {
|
|
continue
|
|
}
|
|
out = append(out, sub)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (m *MockExtensionServer) publishTopicUpdate(subscribers []*stateSubscriber, updates ...[]byte) {
|
|
for _, sub := range subscribers {
|
|
if sub == nil {
|
|
continue
|
|
}
|
|
for _, update := range updates {
|
|
if len(update) == 0 {
|
|
continue
|
|
}
|
|
payload := append([]byte(nil), update...)
|
|
select {
|
|
case sub.updates <- payload:
|
|
default:
|
|
m.logger.Warn("dropping USS update", "topic", sub.topic, "account", sub.accountID)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func buildModelCreditsAppliedUpdates(info *ModelCreditsInfo) [][]byte {
|
|
if info == nil {
|
|
info = &ModelCreditsInfo{}
|
|
}
|
|
minimum := defaultMinimumCreditAmountForUsage
|
|
if info.MinimumCreditAmountForUsage != nil {
|
|
minimum = *info.MinimumCreditAmountForUsage
|
|
}
|
|
|
|
updates := make([][]byte, 0, 3)
|
|
updates = append(updates, buildAppliedUpdate(
|
|
useAICreditsSentinelKey,
|
|
buildUSSRowBinary(encodeUSSPrimitiveBoolValue(info.UseAICredits)),
|
|
))
|
|
|
|
if info.UseAICredits {
|
|
credits := int32(9999)
|
|
if info.AvailableCredits != nil {
|
|
credits = *info.AvailableCredits
|
|
}
|
|
updates = append(updates, buildAppliedUpdate(
|
|
availableCreditsSentinelKey,
|
|
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(credits)),
|
|
))
|
|
} else {
|
|
updates = append(updates, buildAppliedUpdate(availableCreditsSentinelKey, nil))
|
|
}
|
|
updates = append(updates, buildAppliedUpdate(
|
|
minimumCreditAmountForUsageKey,
|
|
buildUSSRowBinary(encodeUSSPrimitiveInt32Value(minimum)),
|
|
))
|
|
|
|
return updates
|
|
}
|
|
|
|
// Close shuts down the server.
|
|
func (m *MockExtensionServer) Close() error {
|
|
return m.server.Close()
|
|
}
|
|
|
|
// ============================================================
|
|
// Middleware
|
|
// ============================================================
|
|
|
|
type (
	// unaryHandler processes one unary RPC: it receives the raw proto request
	// body and returns the raw proto response body (nil or empty for an empty
	// response message).
	unaryHandler func(body []byte) []byte

	// streamHandler processes one server-streaming RPC: it receives the raw
	// proto request body and writes envelope-framed messages to w itself.
	streamHandler func(body []byte, w http.ResponseWriter, r *http.Request)
)
|
|
|
|
// handleUnary wraps a unary RPC handler with CSRF check and proper content-type.
|
|
func (m *MockExtensionServer) handleUnary(handler unaryHandler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// CSRF check
|
|
if !m.checkCSRF(w, r) {
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
|
w.Header().Set("Content-Type", "application/proto")
|
|
w.WriteHeader(200)
|
|
return
|
|
}
|
|
|
|
// The LS might send with envelope framing (application/connect+proto)
|
|
// or without (application/proto). Detect and unwrap.
|
|
ct := r.Header.Get("Content-Type")
|
|
protoBody := body
|
|
if strings.Contains(ct, "connect+proto") && len(body) >= 5 {
|
|
protoBody = unwrapConnectEnvelope(body)
|
|
}
|
|
|
|
m.logger.Debug("unary RPC", "path", r.URL.Path, "body_len", len(protoBody), "content_type", ct)
|
|
|
|
responseProto := handler(protoBody)
|
|
|
|
// Respond with proper unary ConnectRPC content-type.
|
|
// If the request used "connect+proto", the response should be "application/proto"
|
|
// for unary RPCs (ConnectRPC spec: unary uses application/proto, not connect+proto).
|
|
w.Header().Set("Content-Type", "application/proto")
|
|
w.WriteHeader(200)
|
|
if len(responseProto) > 0 {
|
|
w.Write(responseProto)
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleStream wraps a server-streaming RPC handler with CSRF and content-type.
|
|
func (m *MockExtensionServer) handleStream(handler streamHandler) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if !m.checkCSRF(w, r) {
|
|
return
|
|
}
|
|
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
m.logger.Error("read body", "err", err, "path", r.URL.Path)
|
|
return
|
|
}
|
|
|
|
// Unwrap envelope from request
|
|
ct := r.Header.Get("Content-Type")
|
|
if strings.Contains(ct, "connect+proto") || strings.Contains(ct, "connect+json") {
|
|
body = unwrapConnectEnvelope(body)
|
|
}
|
|
|
|
m.logger.Debug("stream RPC", "path", r.URL.Path, "body_len", len(body))
|
|
|
|
// Set streaming response content-type
|
|
w.Header().Set("Content-Type", "application/connect+proto")
|
|
w.WriteHeader(200)
|
|
|
|
handler(body, w, r)
|
|
}
|
|
}
|
|
|
|
func (m *MockExtensionServer) checkCSRF(w http.ResponseWriter, r *http.Request) bool {
|
|
token := r.Header.Get("x-codeium-csrf-token")
|
|
if m.csrf != "" && token != m.csrf {
|
|
m.logger.Warn("CSRF mismatch", "path", r.URL.Path, "got", token[:min(8, len(token))])
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(403)
|
|
w.Write([]byte("Invalid CSRF token"))
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// min returns the smaller of a and b. Within this package it takes precedence
// over the predeclared min function.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
|
|
|
|
// ============================================================
|
|
// Unary RPC Handlers — each receives raw proto request body,
|
|
// returns raw proto response body.
|
|
// ============================================================
|
|
|
|
func (m *MockExtensionServer) onLanguageServerStarted(body []byte) []byte {
|
|
// LanguageServerStartedRequest has: https_port(1), http_port(2), lsp_port(3), csrf_token(4)
|
|
// We just log the ports — they're informational.
|
|
m.logger.Info("LanguageServerStarted",
|
|
"body_len", len(body))
|
|
// Return empty LanguageServerStartedResponse
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onHeartbeat(body []byte) []byte {
|
|
// Return empty HeartbeatResponse
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onGetSecretValue(body []byte) []byte {
|
|
// GetSecretValueRequest: key = field 1
|
|
key := decodeProtoString(body, 1)
|
|
m.logger.Debug("GetSecretValue", "key", key)
|
|
|
|
m.mu.RLock()
|
|
var token string
|
|
if info := m.currentTokenLocked(); info != nil {
|
|
token = info.AccessToken
|
|
}
|
|
m.mu.RUnlock()
|
|
|
|
// GetSecretValueResponse: value = field 1
|
|
if token != "" {
|
|
return encodeProtoString(1, token)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onStoreSecretValue(body []byte) []byte {
|
|
key := decodeProtoString(body, 1)
|
|
m.logger.Debug("StoreSecretValue", "key", key)
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onIsAgentManagerEnabled(body []byte) []byte {
|
|
// IsAgentManagerEnabledResponse: enabled = field 1 (bool)
|
|
return encodeProtoBool(1, false)
|
|
}
|
|
|
|
func (m *MockExtensionServer) onPushUnifiedStateSyncUpdate(body []byte) []byte {
|
|
// PushUnifiedStateSyncUpdateRequest: update = field 1 (UpdateRequest message)
|
|
// UpdateRequest: topic_name = field 1, applied_update = field 5, key = field 2
|
|
m.logger.Debug("PushUnifiedStateSyncUpdate", "body_len", len(body))
|
|
|
|
// Extract topic name from the embedded UpdateRequest
|
|
// The body is PushUnifiedStateSyncUpdateRequest, field 1 is the UpdateRequest
|
|
// We need to dig into the nested message to get topic_name
|
|
if m.onTrajectoryUpdate != nil {
|
|
// For now, just notify that an update was pushed
|
|
m.onTrajectoryUpdate("", "", body)
|
|
}
|
|
|
|
// Return empty PushUnifiedStateSyncUpdateResponse
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onRecordError(body []byte) []byte {
|
|
m.logger.Debug("RecordError", "body_len", len(body))
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onLogEvent(body []byte) []byte {
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onUpdateTrajectorySummaries(body []byte) []byte {
|
|
m.logger.Debug("UpdateCascadeTrajectorySummaries", "body_len", len(body))
|
|
return nil
|
|
}
|
|
|
|
func (m *MockExtensionServer) onDefault(body []byte) []byte {
|
|
return nil
|
|
}
|
|
|
|
// ============================================================
|
|
// Streaming RPC Handlers
|
|
// ============================================================
|
|
|
|
func (m *MockExtensionServer) onSubscribeStateSyncTopic(body []byte, w http.ResponseWriter, r *http.Request) {
|
|
// SubscribeToUnifiedStateSyncTopicRequest: topic = field 1
|
|
topic := decodeProtoString(body, 1)
|
|
m.logger.Info("SubscribeToUnifiedStateSyncTopic", "topic", topic)
|
|
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
m.logger.Error("ResponseWriter does not support Flush")
|
|
return
|
|
}
|
|
|
|
m.mu.Lock()
|
|
accountID := m.lastAccountID
|
|
subID := m.nextSubID
|
|
m.nextSubID++
|
|
sub := &stateSubscriber{
|
|
id: subID,
|
|
accountID: accountID,
|
|
topic: topic,
|
|
updates: make(chan []byte, 16),
|
|
}
|
|
if m.subscribers[topic] == nil {
|
|
m.subscribers[topic] = make(map[int]*stateSubscriber)
|
|
}
|
|
m.subscribers[topic][subID] = sub
|
|
|
|
// Build initial state based on topic
|
|
var topicData []byte
|
|
switch topic {
|
|
case "uss-oauth":
|
|
tokenInfo := m.tokenForAccountLocked(accountID)
|
|
if tokenInfo != nil {
|
|
topicData = buildUSSTopicWithOAuth(tokenInfo.AccessToken, tokenInfo.RefreshToken, tokenInfo.ExpiresAt)
|
|
} else {
|
|
topicData = buildEmptyTopic()
|
|
}
|
|
case "uss-modelCredits":
|
|
creditsInfo := m.creditsForAccountLocked(accountID)
|
|
if creditsInfo != nil {
|
|
topicData = buildUSSTopicWithModelCredits(creditsInfo)
|
|
} else {
|
|
topicData = buildEmptyTopic()
|
|
}
|
|
default:
|
|
// For all other topics (browserPreferences, enterprisePreferences, etc.),
|
|
// return empty topic data.
|
|
topicData = buildEmptyTopic()
|
|
}
|
|
m.mu.Unlock()
|
|
defer func() {
|
|
m.mu.Lock()
|
|
if topicSubs := m.subscribers[topic]; topicSubs != nil {
|
|
delete(topicSubs, subID)
|
|
if len(topicSubs) == 0 {
|
|
delete(m.subscribers, topic)
|
|
}
|
|
}
|
|
m.mu.Unlock()
|
|
}()
|
|
|
|
// Send initial state as envelope-framed message
|
|
initialUpdate := buildInitialStateUpdate(topicData)
|
|
frame := connectEnvelope(0x00, initialUpdate)
|
|
w.Write(frame)
|
|
flusher.Flush()
|
|
|
|
for {
|
|
select {
|
|
case <-r.Context().Done():
|
|
m.logger.Debug("SubscribeToUnifiedStateSyncTopic stream closed", "topic", topic)
|
|
return
|
|
case update := <-sub.updates:
|
|
if len(update) == 0 {
|
|
continue
|
|
}
|
|
if _, err := w.Write(connectEnvelope(0x00, update)); err != nil {
|
|
m.logger.Debug("SubscribeToUnifiedStateSyncTopic write failed", "topic", topic, "err", err)
|
|
return
|
|
}
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *MockExtensionServer) onExecuteCommand(body []byte, w http.ResponseWriter, r *http.Request) {
|
|
m.logger.Debug("ExecuteCommand (mock)", "body_len", len(body))
|
|
// Send end-of-stream immediately — we don't execute commands
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
return
|
|
}
|
|
w.Write(connectEndOfStream())
|
|
flusher.Flush()
|
|
}
|
|
|
|
// ============================================================
|
|
// Catch-all handler
|
|
// ============================================================
|
|
|
|
func (m *MockExtensionServer) handleCatchAll(w http.ResponseWriter, r *http.Request) {
|
|
if !m.checkCSRF(w, r) {
|
|
return
|
|
}
|
|
m.logger.Debug("unhandled RPC (returning empty proto)", "path", r.URL.Path, "method", r.Method)
|
|
|
|
// Drain request body
|
|
io.ReadAll(r.Body)
|
|
|
|
// Determine if this is likely a unary or streaming request based on content-type.
|
|
ct := r.Header.Get("Content-Type")
|
|
if strings.Contains(ct, "connect+") {
|
|
// Could be streaming — respond with unary proto to be safe
|
|
// (unary Connect requests can also use connect+ prefix in some client impls)
|
|
w.Header().Set("Content-Type", "application/proto")
|
|
} else {
|
|
w.Header().Set("Content-Type", "application/proto")
|
|
}
|
|
w.WriteHeader(200)
|
|
}
|