swarm repositories / source
summaryrefslogtreecommitdiff
path: root/crates/ra-mcp-engine/src/lsp_transport.rs
diff options
context:
space:
mode:
authormain <main@swarm.moe>2026-03-19 15:49:41 -0400
committermain <main@swarm.moe>2026-03-19 15:49:41 -0400
commitfa1bd32800b65aab31ea732dd240261b4047522c (patch)
tree2fd08af6f36b8beb3c7c941990becc1a0a091d62 /crates/ra-mcp-engine/src/lsp_transport.rs
downloadadequate-rust-mcp-fa1bd32800b65aab31ea732dd240261b4047522c.zip
Release adequate-rust-mcp 1.0.0v1.0.0
Diffstat (limited to 'crates/ra-mcp-engine/src/lsp_transport.rs')
-rw-r--r--crates/ra-mcp-engine/src/lsp_transport.rs717
1 files changed, 717 insertions, 0 deletions
diff --git a/crates/ra-mcp-engine/src/lsp_transport.rs b/crates/ra-mcp-engine/src/lsp_transport.rs
new file mode 100644
index 0000000..c47d4f2
--- /dev/null
+++ b/crates/ra-mcp-engine/src/lsp_transport.rs
@@ -0,0 +1,717 @@
+use crate::config::EngineConfig;
+use ra_mcp_domain::{
+ fault::{Fault, FaultClass, FaultCode, FaultDetail},
+ types::Generation,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::{Value, json};
+use std::{
+ collections::HashMap,
+ io,
+ process::Stdio,
+ sync::{
+ Arc,
+ atomic::{AtomicU64, Ordering},
+ },
+ time::Duration,
+};
+use tokio::{
+ io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
+ process::{Child, ChildStdin, ChildStdout, Command},
+ sync::{Mutex, oneshot, watch},
+ task::JoinHandle,
+};
+use tracing::{debug, warn};
+use url::Url;
+
+#[derive(Debug, Clone)]
+pub(crate) struct WorkerHandle {
+ generation: Generation,
+ child: Arc<Mutex<Child>>,
+ writer: Arc<Mutex<ChildStdin>>,
+ pending: Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>,
+ next_request_id: Arc<AtomicU64>,
+ terminal_fault_rx: watch::Receiver<Option<Fault>>,
+ reader_task: Arc<Mutex<Option<JoinHandle<()>>>>,
+ stderr_task: Arc<Mutex<Option<JoinHandle<()>>>>,
+}
+
+#[derive(Debug)]
+enum PendingOutcome {
+ Result(Value),
+ ResponseError(RpcErrorPayload),
+ TransportFault(Fault),
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub(crate) struct RpcErrorPayload {
+ pub(crate) code: i64,
+ pub(crate) message: String,
+ pub(crate) data: Option<Value>,
+}
+
+#[derive(Debug)]
+pub(crate) enum WorkerRequestError {
+ Fault(Fault),
+ Response(RpcErrorPayload),
+}
+
+impl WorkerHandle {
+ pub(crate) fn terminal_fault(&self) -> Option<Fault> {
+ self.terminal_fault_rx.borrow().clone()
+ }
+
+ pub(crate) async fn send_notification(
+ &self,
+ method: &'static str,
+ params: &impl Serialize,
+ ) -> Result<(), Fault> {
+ let payload = json!({
+ "jsonrpc": "2.0",
+ "method": method,
+ "params": params,
+ });
+ let mut writer = self.writer.lock().await;
+ write_frame(&mut writer, &payload).await.map_err(|error| {
+ classify_io_fault(
+ self.generation,
+ FaultClass::Transport,
+ "failed to write notification",
+ error,
+ )
+ })
+ }
+
+ pub(crate) async fn send_request(
+ &self,
+ method: &'static str,
+ params: &impl Serialize,
+ timeout: Duration,
+ ) -> Result<Value, WorkerRequestError> {
+ let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
+ let (sender, receiver) = oneshot::channel::<PendingOutcome>();
+ {
+ let mut pending = self.pending.lock().await;
+ let previous = pending.insert(request_id, sender);
+ if let Some(previous_sender) = previous {
+ drop(previous_sender);
+ }
+ }
+
+ let payload = json!({
+ "jsonrpc": "2.0",
+ "id": request_id,
+ "method": method,
+ "params": params,
+ });
+
+ let write_result = {
+ let mut writer = self.writer.lock().await;
+ write_frame(&mut writer, &payload).await
+ };
+
+ if let Err(error) = write_result {
+ let mut pending = self.pending.lock().await;
+ let removed = pending.remove(&request_id);
+ if let Some(sender) = removed {
+ drop(sender);
+ }
+ return Err(WorkerRequestError::Fault(classify_io_fault(
+ self.generation,
+ FaultClass::Transport,
+ "failed to write request",
+ error,
+ )));
+ }
+
+ match tokio::time::timeout(timeout, receiver).await {
+ Ok(Ok(PendingOutcome::Result(value))) => Ok(value),
+ Ok(Ok(PendingOutcome::ResponseError(error))) => {
+ Err(WorkerRequestError::Response(error))
+ }
+ Ok(Ok(PendingOutcome::TransportFault(fault))) => Err(WorkerRequestError::Fault(fault)),
+ Ok(Err(_closed)) => Err(WorkerRequestError::Fault(Fault::new(
+ self.generation,
+ FaultClass::Transport,
+ FaultCode::UnexpectedEof,
+ FaultDetail::new("response channel closed before result"),
+ ))),
+ Err(_elapsed) => {
+ let mut pending = self.pending.lock().await;
+ let removed = pending.remove(&request_id);
+ if let Some(sender) = removed {
+ drop(sender);
+ }
+ Err(WorkerRequestError::Fault(Fault::new(
+ self.generation,
+ FaultClass::Timeout,
+ FaultCode::RequestTimedOut,
+ FaultDetail::new(format!("request timed out for method {method}")),
+ )))
+ }
+ }
+ }
+
+ pub(crate) async fn terminate(&self) {
+ let mut child = self.child.lock().await;
+ if child.id().is_some()
+ && let Err(error) = child.kill().await
+ {
+ debug!(
+ generation = self.generation.get(),
+ "failed to kill rust-analyzer process cleanly: {error}"
+ );
+ }
+ if let Err(error) = child.wait().await {
+ debug!(
+ generation = self.generation.get(),
+ "failed to wait rust-analyzer process cleanly: {error}"
+ );
+ }
+
+ if let Some(task) = self.reader_task.lock().await.take() {
+ task.abort();
+ }
+ if let Some(task) = self.stderr_task.lock().await.take() {
+ task.abort();
+ }
+ }
+}
+
+pub(crate) async fn spawn_worker(
+ config: &EngineConfig,
+ generation: Generation,
+) -> Result<WorkerHandle, Fault> {
+ let mut command = Command::new(&config.rust_analyzer_binary);
+ let _args = command.args(&config.rust_analyzer_args);
+ let _envs = command.envs(config.rust_analyzer_env.iter().cloned());
+ let _configured = command
+ .current_dir(config.workspace_root.as_path())
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped());
+
+ let mut child = command.spawn().map_err(|error| {
+ classify_io_fault(
+ generation,
+ FaultClass::Process,
+ "failed to spawn rust-analyzer",
+ error,
+ )
+ })?;
+
+ let stdin = child.stdin.take().ok_or_else(|| {
+ Fault::new(
+ generation,
+ FaultClass::Process,
+ FaultCode::SpawnFailed,
+ FaultDetail::new("missing stdin pipe from rust-analyzer process"),
+ )
+ })?;
+ let stdout = child.stdout.take().ok_or_else(|| {
+ Fault::new(
+ generation,
+ FaultClass::Process,
+ FaultCode::SpawnFailed,
+ FaultDetail::new("missing stdout pipe from rust-analyzer process"),
+ )
+ })?;
+ let stderr = child.stderr.take().ok_or_else(|| {
+ Fault::new(
+ generation,
+ FaultClass::Process,
+ FaultCode::SpawnFailed,
+ FaultDetail::new("missing stderr pipe from rust-analyzer process"),
+ )
+ })?;
+
+ let child = Arc::new(Mutex::new(child));
+ let writer = Arc::new(Mutex::new(stdin));
+ let pending = Arc::new(Mutex::new(
+ HashMap::<u64, oneshot::Sender<PendingOutcome>>::new(),
+ ));
+ let next_request_id = Arc::new(AtomicU64::new(1));
+ let (terminal_fault_tx, terminal_fault_rx) = watch::channel(None::<Fault>);
+
+ let reader_task = {
+ let pending = Arc::clone(&pending);
+ let terminal_fault_tx = terminal_fault_tx.clone();
+ tokio::spawn(async move {
+ read_stdout_loop(generation, stdout, pending, terminal_fault_tx).await;
+ })
+ };
+
+ let stderr_task = tokio::spawn(async move {
+ stream_stderr(generation, stderr).await;
+ });
+
+ let handle = WorkerHandle {
+ generation,
+ child,
+ writer,
+ pending,
+ next_request_id,
+ terminal_fault_rx,
+ reader_task: Arc::new(Mutex::new(Some(reader_task))),
+ stderr_task: Arc::new(Mutex::new(Some(stderr_task))),
+ };
+
+ let initialize_params = build_initialize_params(config)?;
+ let startup = handle
+ .send_request("initialize", &initialize_params, config.startup_timeout)
+ .await;
+ if let Err(error) = startup {
+ handle.terminate().await;
+ return Err(map_worker_request_error(generation, error));
+ }
+
+ let initialized_params = json!({});
+ let initialized_result = handle
+ .send_notification("initialized", &initialized_params)
+ .await
+ .map_err(|fault| {
+ handle_fault_notification(generation, "initialized notification failed", fault)
+ });
+ if let Err(fault) = initialized_result {
+ handle.terminate().await;
+ return Err(fault);
+ }
+
+ Ok(handle)
+}
+
+fn map_worker_request_error(generation: Generation, error: WorkerRequestError) -> Fault {
+ match error {
+ WorkerRequestError::Fault(fault) => fault,
+ WorkerRequestError::Response(response) => Fault::new(
+ generation,
+ FaultClass::Protocol,
+ FaultCode::InvalidFrame,
+ FaultDetail::new(format!(
+ "initialize returned LSP error {}: {}",
+ response.code, response.message
+ )),
+ ),
+ }
+}
+
+fn handle_fault_notification(generation: Generation, context: &'static str, fault: Fault) -> Fault {
+ let detail = FaultDetail::new(format!("{context}: {}", fault.detail.message));
+ Fault::new(generation, fault.class, fault.code, detail)
+}
+
+fn build_initialize_params(config: &EngineConfig) -> Result<Value, Fault> {
+ let root_uri = Url::from_directory_path(config.workspace_root.as_path()).map_err(|()| {
+ Fault::new(
+ Generation::genesis(),
+ FaultClass::Protocol,
+ FaultCode::InvalidFrame,
+ FaultDetail::new("workspace root cannot be represented as file URI"),
+ )
+ })?;
+ let folder_name = config
+ .workspace_root
+ .as_path()
+ .file_name()
+ .and_then(|value| value.to_str())
+ .unwrap_or("workspace")
+ .to_owned();
+ let root_uri_string = root_uri.to_string();
+ Ok(json!({
+ "processId": std::process::id(),
+ "rootUri": root_uri_string.clone(),
+ "capabilities": build_client_capabilities(),
+ "workspaceFolders": [{
+ "uri": root_uri_string,
+ "name": folder_name,
+ }],
+ "trace": "off",
+ "clientInfo": {
+ "name": "adequate-rust-mcp",
+ "version": env!("CARGO_PKG_VERSION"),
+ }
+ }))
+}
+
+fn build_client_capabilities() -> Value {
+ let symbol_kind_values = (1_u32..=26).collect::<Vec<_>>();
+ json!({
+ "workspace": {
+ "applyEdit": true,
+ "workspaceEdit": {
+ "documentChanges": true,
+ "resourceOperations": ["create", "rename", "delete"],
+ },
+ "symbol": {
+ "dynamicRegistration": false,
+ "resolveSupport": {
+ "properties": ["location.range", "containerName"],
+ },
+ },
+ "diagnostics": {
+ "refreshSupport": true,
+ },
+ "executeCommand": {
+ "dynamicRegistration": false,
+ },
+ "workspaceFolders": true,
+ "configuration": true,
+ },
+ "textDocument": {
+ "synchronization": {
+ "dynamicRegistration": false,
+ "willSave": false,
+ "didSave": true,
+ "willSaveWaitUntil": false,
+ },
+ "hover": {
+ "dynamicRegistration": false,
+ "contentFormat": ["markdown", "plaintext"],
+ },
+ "definition": {
+ "dynamicRegistration": false,
+ "linkSupport": true,
+ },
+ "declaration": {
+ "dynamicRegistration": false,
+ "linkSupport": true,
+ },
+ "typeDefinition": {
+ "dynamicRegistration": false,
+ "linkSupport": true,
+ },
+ "implementation": {
+ "dynamicRegistration": false,
+ "linkSupport": true,
+ },
+ "references": {
+ "dynamicRegistration": false,
+ },
+ "documentHighlight": {
+ "dynamicRegistration": false,
+ },
+ "documentSymbol": {
+ "dynamicRegistration": false,
+ "hierarchicalDocumentSymbolSupport": true,
+ "symbolKind": {
+ "valueSet": symbol_kind_values,
+ },
+ },
+ "completion": {
+ "dynamicRegistration": false,
+ "contextSupport": true,
+ "completionItem": {
+ "snippetSupport": true,
+ "documentationFormat": ["markdown", "plaintext"],
+ "resolveSupport": {
+ "properties": ["documentation", "detail", "additionalTextEdits"],
+ },
+ },
+ },
+ "signatureHelp": {
+ "dynamicRegistration": false,
+ },
+ "codeAction": {
+ "dynamicRegistration": false,
+ "isPreferredSupport": true,
+ "codeActionLiteralSupport": {
+ "codeActionKind": {
+ "valueSet": [
+ "",
+ "quickfix",
+ "refactor",
+ "refactor.extract",
+ "refactor.inline",
+ "refactor.rewrite",
+ "source",
+ "source.organizeImports",
+ ],
+ },
+ },
+ },
+ "codeLens": {
+ "dynamicRegistration": false,
+ },
+ "documentLink": {
+ "dynamicRegistration": false,
+ "tooltipSupport": true,
+ },
+ "colorProvider": {
+ "dynamicRegistration": false,
+ },
+ "linkedEditingRange": {
+ "dynamicRegistration": false,
+ },
+ "rename": {
+ "dynamicRegistration": false,
+ "prepareSupport": true,
+ },
+ "typeHierarchy": {
+ "dynamicRegistration": false,
+ },
+ "inlineValue": {
+ "dynamicRegistration": false,
+ },
+ "moniker": {
+ "dynamicRegistration": false,
+ },
+ "diagnostic": {
+ "dynamicRegistration": false,
+ },
+ "documentFormatting": {
+ "dynamicRegistration": false,
+ },
+ "documentRangeFormatting": {
+ "dynamicRegistration": false,
+ },
+ "documentOnTypeFormatting": {
+ "dynamicRegistration": false,
+ },
+ "foldingRange": {
+ "dynamicRegistration": false,
+ },
+ "selectionRange": {
+ "dynamicRegistration": false,
+ },
+ "inlayHint": {
+ "dynamicRegistration": false,
+ "resolveSupport": {
+ "properties": ["tooltip", "textEdits", "label.tooltip", "label.location", "label.command"],
+ },
+ },
+ "semanticTokens": {
+ "dynamicRegistration": false,
+ "requests": {
+ "full": {
+ "delta": true,
+ },
+ "range": true,
+ },
+ "tokenTypes": [
+ "namespace", "type", "class", "enum", "interface", "struct", "typeParameter",
+ "parameter", "variable", "property", "enumMember", "event", "function",
+ "method", "macro", "keyword", "modifier", "comment", "string", "number",
+ "regexp", "operator",
+ ],
+ "tokenModifiers": [
+ "declaration", "definition", "readonly", "static", "deprecated", "abstract",
+ "async", "modification", "documentation", "defaultLibrary",
+ ],
+ "formats": ["relative"],
+ "multilineTokenSupport": true,
+ "overlappingTokenSupport": true,
+ },
+ "publishDiagnostics": {
+ "relatedInformation": true,
+ },
+ },
+ "window": {
+ "workDoneProgress": true,
+ },
+ "general": {
+ "positionEncodings": ["utf-8", "utf-16"],
+ },
+ })
+}
+
+async fn stream_stderr(generation: Generation, stderr: tokio::process::ChildStderr) {
+ let mut reader = BufReader::new(stderr).lines();
+ loop {
+ match reader.next_line().await {
+ Ok(Some(line)) => {
+ debug!(
+ generation = generation.get(),
+ "rust-analyzer stderr: {line}"
+ );
+ }
+ Ok(None) => break,
+ Err(error) => {
+ debug!(
+ generation = generation.get(),
+ "rust-analyzer stderr stream failed: {error}"
+ );
+ break;
+ }
+ }
+ }
+}
+
+async fn read_stdout_loop(
+ generation: Generation,
+ stdout: ChildStdout,
+ pending: Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>,
+ terminal_fault_tx: watch::Sender<Option<Fault>>,
+) {
+ let mut reader = BufReader::new(stdout);
+ loop {
+ match read_frame(&mut reader).await {
+ Ok(frame) => {
+ if let Err(fault) = dispatch_frame(generation, &pending, &frame).await {
+ emit_terminal_fault(&terminal_fault_tx, &pending, fault).await;
+ break;
+ }
+ }
+ Err(error) => {
+ let fault = classify_io_fault(
+ generation,
+ FaultClass::Transport,
+ "failed to read frame",
+ error,
+ );
+ emit_terminal_fault(&terminal_fault_tx, &pending, fault).await;
+ break;
+ }
+ }
+ }
+}
+
+async fn emit_terminal_fault(
+ terminal_fault_tx: &watch::Sender<Option<Fault>>,
+ pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>,
+ fault: Fault,
+) {
+ if let Err(error) = terminal_fault_tx.send(Some(fault.clone())) {
+ warn!("failed to publish terminal fault: {error}");
+ }
+ let mut pending_guard = pending.lock().await;
+ for sender in pending_guard.drain().map(|(_id, sender)| sender) {
+ if let Err(outcome) = sender.send(PendingOutcome::TransportFault(fault.clone())) {
+ drop(outcome);
+ }
+ }
+}
+
+async fn dispatch_frame(
+ generation: Generation,
+ pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>,
+ frame: &[u8],
+) -> Result<(), Fault> {
+ let value: Value = serde_json::from_slice(frame).map_err(|error| {
+ Fault::new(
+ generation,
+ FaultClass::Protocol,
+ FaultCode::InvalidJson,
+ FaultDetail::new(format!("failed to deserialize JSON-RPC frame: {error}")),
+ )
+ })?;
+
+ let response_id = value.get("id").and_then(Value::as_u64);
+ let Some(response_id) = response_id else {
+ return Ok(());
+ };
+
+ let mut pending_guard = pending.lock().await;
+ let Some(sender) = pending_guard.remove(&response_id) else {
+ warn!(
+ generation = generation.get(),
+ response_id, "received response for unknown request id"
+ );
+ return Ok(());
+ };
+ drop(pending_guard);
+
+ if let Some(result) = value.get("result") {
+ if let Err(outcome) = sender.send(PendingOutcome::Result(result.clone())) {
+ drop(outcome);
+ }
+ return Ok(());
+ }
+
+ if let Some(error_value) = value.get("error") {
+ let error: RpcErrorPayload =
+ serde_json::from_value(error_value.clone()).map_err(|error| {
+ Fault::new(
+ generation,
+ FaultClass::Protocol,
+ FaultCode::InvalidJson,
+ FaultDetail::new(format!(
+ "failed to deserialize JSON-RPC error payload: {error}"
+ )),
+ )
+ })?;
+ if let Err(outcome) = sender.send(PendingOutcome::ResponseError(error)) {
+ drop(outcome);
+ }
+ return Ok(());
+ }
+
+ Err(Fault::new(
+ generation,
+ FaultClass::Protocol,
+ FaultCode::InvalidFrame,
+ FaultDetail::new("response frame missing both result and error"),
+ ))
+}
+
+async fn read_frame(reader: &mut BufReader<ChildStdout>) -> Result<Vec<u8>, io::Error> {
+ let mut content_length = None::<usize>;
+ loop {
+ let mut header_line = String::new();
+ let bytes_read = reader.read_line(&mut header_line).await?;
+ if bytes_read == 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ "EOF while reading headers",
+ ));
+ }
+
+ if header_line == "\r\n" || header_line == "\n" {
+ break;
+ }
+
+ let trimmed = header_line.trim_end_matches(['\r', '\n']);
+ if let Some(length) = trimmed.strip_prefix("Content-Length:") {
+ let parsed = length.trim().parse::<usize>().map_err(|parse_error| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("invalid Content-Length header: {parse_error}"),
+ )
+ })?;
+ content_length = Some(parsed);
+ }
+ }
+
+ let length = content_length.ok_or_else(|| {
+ io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
+ })?;
+
+ let mut payload = vec![0_u8; length];
+ let _bytes_read = reader.read_exact(&mut payload).await?;
+ Ok(payload)
+}
+
+async fn write_frame(writer: &mut ChildStdin, value: &Value) -> Result<(), io::Error> {
+ let payload = serde_json::to_vec(value).map_err(|error| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("failed to serialize JSON-RPC payload: {error}"),
+ )
+ })?;
+ let header = format!("Content-Length: {}\r\n\r\n", payload.len());
+ writer.write_all(header.as_bytes()).await?;
+ writer.write_all(&payload).await?;
+ writer.flush().await?;
+ Ok(())
+}
+
+fn classify_io_fault(
+ generation: Generation,
+ class: FaultClass,
+ context: &'static str,
+ error: io::Error,
+) -> Fault {
+ let code = match error.kind() {
+ io::ErrorKind::BrokenPipe => FaultCode::BrokenPipe,
+ io::ErrorKind::UnexpectedEof => FaultCode::UnexpectedEof,
+ _ => match class {
+ FaultClass::Process => FaultCode::SpawnFailed,
+ _ => FaultCode::InvalidFrame,
+ },
+ };
+ Fault::new(
+ generation,
+ class,
+ code,
+ FaultDetail::new(format!("{context}: {error}")),
+ )
+}