diff options
Diffstat (limited to 'crates/ra-mcp-engine/src/lsp_transport.rs')
| -rw-r--r-- | crates/ra-mcp-engine/src/lsp_transport.rs | 717 |
1 files changed, 717 insertions, 0 deletions
diff --git a/crates/ra-mcp-engine/src/lsp_transport.rs b/crates/ra-mcp-engine/src/lsp_transport.rs new file mode 100644 index 0000000..c47d4f2 --- /dev/null +++ b/crates/ra-mcp-engine/src/lsp_transport.rs @@ -0,0 +1,717 @@ +use crate::config::EngineConfig; +use ra_mcp_domain::{ + fault::{Fault, FaultClass, FaultCode, FaultDetail}, + types::Generation, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use std::{ + collections::HashMap, + io, + process::Stdio, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::Duration, +}; +use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{Child, ChildStdin, ChildStdout, Command}, + sync::{Mutex, oneshot, watch}, + task::JoinHandle, +}; +use tracing::{debug, warn}; +use url::Url; + +#[derive(Debug, Clone)] +pub(crate) struct WorkerHandle { + generation: Generation, + child: Arc<Mutex<Child>>, + writer: Arc<Mutex<ChildStdin>>, + pending: Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + next_request_id: Arc<AtomicU64>, + terminal_fault_rx: watch::Receiver<Option<Fault>>, + reader_task: Arc<Mutex<Option<JoinHandle<()>>>>, + stderr_task: Arc<Mutex<Option<JoinHandle<()>>>>, +} + +#[derive(Debug)] +enum PendingOutcome { + Result(Value), + ResponseError(RpcErrorPayload), + TransportFault(Fault), +} + +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct RpcErrorPayload { + pub(crate) code: i64, + pub(crate) message: String, + pub(crate) data: Option<Value>, +} + +#[derive(Debug)] +pub(crate) enum WorkerRequestError { + Fault(Fault), + Response(RpcErrorPayload), +} + +impl WorkerHandle { + pub(crate) fn terminal_fault(&self) -> Option<Fault> { + self.terminal_fault_rx.borrow().clone() + } + + pub(crate) async fn send_notification( + &self, + method: &'static str, + params: &impl Serialize, + ) -> Result<(), Fault> { + let payload = json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + let mut writer = self.writer.lock().await; + write_frame(&mut writer, &payload).await.map_err(|error| { + classify_io_fault( + self.generation, + FaultClass::Transport, + "failed to write notification", + error, + ) + }) + } + + pub(crate) async fn send_request( + &self, + method: &'static str, + params: &impl Serialize, + timeout: Duration, + ) -> Result<Value, WorkerRequestError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let (sender, receiver) = oneshot::channel::<PendingOutcome>(); + { + let mut pending = self.pending.lock().await; + let previous = pending.insert(request_id, sender); + if let Some(previous_sender) = previous { + drop(previous_sender); + } + } + + let payload = json!({ + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params, + }); + + let write_result = { + let mut writer = self.writer.lock().await; + write_frame(&mut writer, &payload).await + }; + + if let Err(error) = write_result { + let mut pending = self.pending.lock().await; + let removed = pending.remove(&request_id); + if let Some(sender) = removed { + drop(sender); + } + return Err(WorkerRequestError::Fault(classify_io_fault( + self.generation, + FaultClass::Transport, + "failed to write request", + error, + ))); + } + + match tokio::time::timeout(timeout, receiver).await { + Ok(Ok(PendingOutcome::Result(value))) => Ok(value), + Ok(Ok(PendingOutcome::ResponseError(error))) => { + Err(WorkerRequestError::Response(error)) + } + Ok(Ok(PendingOutcome::TransportFault(fault))) => Err(WorkerRequestError::Fault(fault)), + Ok(Err(_closed)) => Err(WorkerRequestError::Fault(Fault::new( + self.generation, + FaultClass::Transport, + FaultCode::UnexpectedEof, + FaultDetail::new("response channel closed before result"), + ))), + Err(_elapsed) => { + let mut pending = self.pending.lock().await; + let removed = pending.remove(&request_id); + if let Some(sender) = removed { + drop(sender); + } + Err(WorkerRequestError::Fault(Fault::new( + self.generation, + FaultClass::Timeout, + FaultCode::RequestTimedOut, + FaultDetail::new(format!("request timed out for method {method}")), + ))) + } + } + } + + pub(crate) async fn terminate(&self) { + let mut child = self.child.lock().await; + if child.id().is_some() + && let Err(error) = child.kill().await + { + debug!( + generation = self.generation.get(), + "failed to kill rust-analyzer process cleanly: {error}" + ); + } + if let Err(error) = child.wait().await { + debug!( + generation = self.generation.get(), + "failed to wait rust-analyzer process cleanly: {error}" + ); + } + + if let Some(task) = self.reader_task.lock().await.take() { + task.abort(); + } + if let Some(task) = self.stderr_task.lock().await.take() { + task.abort(); + } + } +} + +pub(crate) async fn spawn_worker( + config: &EngineConfig, + generation: Generation, +) -> Result<WorkerHandle, Fault> { + let mut command = Command::new(&config.rust_analyzer_binary); + let _args = command.args(&config.rust_analyzer_args); + let _envs = command.envs(config.rust_analyzer_env.iter().cloned()); + let _configured = command + .current_dir(config.workspace_root.as_path()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = command.spawn().map_err(|error| { + classify_io_fault( + generation, + FaultClass::Process, + "failed to spawn rust-analyzer", + error, + ) + })?; + + let stdin = child.stdin.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stdin pipe from rust-analyzer process"), + ) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stdout pipe from rust-analyzer process"), + ) + })?; + let stderr = child.stderr.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stderr pipe from rust-analyzer process"), + ) + })?; + + let child = Arc::new(Mutex::new(child)); + let writer = Arc::new(Mutex::new(stdin)); + let pending = Arc::new(Mutex::new( + HashMap::<u64, oneshot::Sender<PendingOutcome>>::new(), + )); + let next_request_id = Arc::new(AtomicU64::new(1)); + let (terminal_fault_tx, terminal_fault_rx) = watch::channel(None::<Fault>); + + let reader_task = { + let pending = Arc::clone(&pending); + let terminal_fault_tx = terminal_fault_tx.clone(); + tokio::spawn(async move { + read_stdout_loop(generation, stdout, pending, terminal_fault_tx).await; + }) + }; + + let stderr_task = tokio::spawn(async move { + stream_stderr(generation, stderr).await; + }); + + let handle = WorkerHandle { + generation, + child, + writer, + pending, + next_request_id, + terminal_fault_rx, + reader_task: Arc::new(Mutex::new(Some(reader_task))), + stderr_task: Arc::new(Mutex::new(Some(stderr_task))), + }; + + let initialize_params = build_initialize_params(config)?; + let startup = handle + .send_request("initialize", &initialize_params, config.startup_timeout) + .await; + if let Err(error) = startup { + handle.terminate().await; + return Err(map_worker_request_error(generation, error)); + } + + let initialized_params = json!({}); + let initialized_result = handle + .send_notification("initialized", &initialized_params) + .await + .map_err(|fault| { + handle_fault_notification(generation, "initialized notification failed", fault) + }); + if let Err(fault) = initialized_result { + handle.terminate().await; + return Err(fault); + } + + Ok(handle) +} + +fn map_worker_request_error(generation: Generation, error: WorkerRequestError) -> Fault { + match error { + WorkerRequestError::Fault(fault) => fault, + WorkerRequestError::Response(response) => Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new(format!( + "initialize returned LSP error {}: {}", + response.code, response.message + )), + ), + } +} + +fn handle_fault_notification(generation: Generation, context: &'static str, fault: Fault) -> Fault { + let detail = FaultDetail::new(format!("{context}: {}", fault.detail.message)); + Fault::new(generation, fault.class, fault.code, detail) +} + +fn build_initialize_params(config: &EngineConfig) -> Result<Value, Fault> { + let root_uri = Url::from_directory_path(config.workspace_root.as_path()).map_err(|()| { + Fault::new( + Generation::genesis(), + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new("workspace root cannot be represented as file URI"), + ) + })?; + let folder_name = config + .workspace_root + .as_path() + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or("workspace") + .to_owned(); + let root_uri_string = root_uri.to_string(); + Ok(json!({ + "processId": std::process::id(), + "rootUri": root_uri_string.clone(), + "capabilities": build_client_capabilities(), + "workspaceFolders": [{ + "uri": root_uri_string, + "name": folder_name, + }], + "trace": "off", + "clientInfo": { + "name": "adequate-rust-mcp", + "version": env!("CARGO_PKG_VERSION"), + } + })) +} + +fn build_client_capabilities() -> Value { + let symbol_kind_values = (1_u32..=26).collect::<Vec<_>>(); + json!({ + "workspace": { + "applyEdit": true, + "workspaceEdit": { + "documentChanges": true, + "resourceOperations": ["create", "rename", "delete"], + }, + "symbol": { + "dynamicRegistration": false, + "resolveSupport": { + "properties": ["location.range", "containerName"], + }, + }, + "diagnostics": { + "refreshSupport": true, + }, + "executeCommand": { + "dynamicRegistration": false, + }, + "workspaceFolders": true, + "configuration": true, + }, + "textDocument": { + "synchronization": { + "dynamicRegistration": false, + "willSave": false, + "didSave": true, + "willSaveWaitUntil": false, + }, + "hover": { + "dynamicRegistration": false, + "contentFormat": ["markdown", "plaintext"], + }, + "definition": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "declaration": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "typeDefinition": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "implementation": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "references": { + "dynamicRegistration": false, + }, + "documentHighlight": { + "dynamicRegistration": false, + }, + "documentSymbol": { + "dynamicRegistration": false, + "hierarchicalDocumentSymbolSupport": true, + "symbolKind": { + "valueSet": symbol_kind_values, + }, + }, + "completion": { + "dynamicRegistration": false, + "contextSupport": true, + "completionItem": { + "snippetSupport": true, + "documentationFormat": ["markdown", "plaintext"], + "resolveSupport": { + "properties": ["documentation", "detail", "additionalTextEdits"], + }, + }, + }, + "signatureHelp": { + "dynamicRegistration": false, + }, + "codeAction": { + "dynamicRegistration": false, + "isPreferredSupport": true, + "codeActionLiteralSupport": { + "codeActionKind": { + "valueSet": [ + "", + "quickfix", + "refactor", + "refactor.extract", + "refactor.inline", + "refactor.rewrite", + "source", + "source.organizeImports", + ], + }, + }, + }, + "codeLens": { + "dynamicRegistration": false, + }, + "documentLink": { + "dynamicRegistration": false, + "tooltipSupport": true, + }, + "colorProvider": { + "dynamicRegistration": false, + }, + "linkedEditingRange": { + "dynamicRegistration": false, + }, + "rename": { + "dynamicRegistration": false, + "prepareSupport": true, + }, + "typeHierarchy": { + "dynamicRegistration": false, + }, + "inlineValue": { + "dynamicRegistration": false, + }, + "moniker": { + "dynamicRegistration": false, + }, + "diagnostic": { + "dynamicRegistration": false, + }, + "documentFormatting": { + "dynamicRegistration": false, + }, + "documentRangeFormatting": { + "dynamicRegistration": false, + }, + "documentOnTypeFormatting": { + "dynamicRegistration": false, + }, + "foldingRange": { + "dynamicRegistration": false, + }, + "selectionRange": { + "dynamicRegistration": false, + }, + "inlayHint": { + "dynamicRegistration": false, + "resolveSupport": { + "properties": ["tooltip", "textEdits", "label.tooltip", "label.location", "label.command"], + }, + }, + "semanticTokens": { + "dynamicRegistration": false, + "requests": { + "full": { + "delta": true, + }, + "range": true, + }, + "tokenTypes": [ + "namespace", "type", "class", "enum", "interface", "struct", "typeParameter", + "parameter", "variable", "property", "enumMember", "event", "function", + "method", "macro", "keyword", "modifier", "comment", "string", "number", + "regexp", "operator", + ], + "tokenModifiers": [ + "declaration", "definition", "readonly", "static", "deprecated", "abstract", + "async", "modification", "documentation", "defaultLibrary", + ], + "formats": ["relative"], + "multilineTokenSupport": true, + "overlappingTokenSupport": true, + }, + "publishDiagnostics": { + "relatedInformation": true, + }, + }, + "window": { + "workDoneProgress": true, + }, + "general": { + "positionEncodings": ["utf-8", "utf-16"], + }, + }) +} + +async fn stream_stderr(generation: Generation, stderr: tokio::process::ChildStderr) { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + debug!( + generation = generation.get(), + "rust-analyzer stderr: {line}" + ); + } + Ok(None) => break, + Err(error) => { + debug!( + generation = generation.get(), + "rust-analyzer stderr stream failed: {error}" + ); + break; + } + } + } +} + +async fn read_stdout_loop( + generation: Generation, + stdout: ChildStdout, + pending: Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + terminal_fault_tx: watch::Sender<Option<Fault>>, +) { + let mut reader = BufReader::new(stdout); + loop { + match read_frame(&mut reader).await { + Ok(frame) => { + if let Err(fault) = dispatch_frame(generation, &pending, &frame).await { + emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; + break; + } + } + Err(error) => { + let fault = classify_io_fault( + generation, + FaultClass::Transport, + "failed to read frame", + error, + ); + emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; + break; + } + } + } +} + +async fn emit_terminal_fault( + terminal_fault_tx: &watch::Sender<Option<Fault>>, + pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + fault: Fault, +) { + if let Err(error) = terminal_fault_tx.send(Some(fault.clone())) { + warn!("failed to publish terminal fault: {error}"); + } + let mut pending_guard = pending.lock().await; + for sender in pending_guard.drain().map(|(_id, sender)| sender) { + if let Err(outcome) = sender.send(PendingOutcome::TransportFault(fault.clone())) { + drop(outcome); + } + } +} + +async fn dispatch_frame( + generation: Generation, + pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + frame: &[u8], +) -> Result<(), Fault> { + let value: Value = serde_json::from_slice(frame).map_err(|error| { + Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidJson, + FaultDetail::new(format!("failed to deserialize JSON-RPC frame: {error}")), + ) + })?; + + let response_id = value.get("id").and_then(Value::as_u64); + let Some(response_id) = response_id else { + return Ok(()); + }; + + let mut pending_guard = pending.lock().await; + let Some(sender) = pending_guard.remove(&response_id) else { + warn!( + generation = generation.get(), + response_id, "received response for unknown request id" + ); + return Ok(()); + }; + drop(pending_guard); + + if let Some(result) = value.get("result") { + if let Err(outcome) = sender.send(PendingOutcome::Result(result.clone())) { + drop(outcome); + } + return Ok(()); + } + + if let Some(error_value) = value.get("error") { + let error: RpcErrorPayload = + serde_json::from_value(error_value.clone()).map_err(|error| { + Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidJson, + FaultDetail::new(format!( + "failed to deserialize JSON-RPC error payload: {error}" + )), + ) + })?; + if let Err(outcome) = sender.send(PendingOutcome::ResponseError(error)) { + drop(outcome); + } + return Ok(()); + } + + Err(Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new("response frame missing both result and error"), + )) +} + +async fn read_frame(reader: &mut BufReader<ChildStdout>) -> Result<Vec<u8>, io::Error> { + let mut content_length = None::<usize>; + loop { + let mut header_line = String::new(); + let bytes_read = reader.read_line(&mut header_line).await?; + if bytes_read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF while reading headers", + )); + } + + if header_line == "\r\n" || header_line == "\n" { + break; + } + + let trimmed = header_line.trim_end_matches(['\r', '\n']); + if let Some(length) = trimmed.strip_prefix("Content-Length:") { + let parsed = length.trim().parse::<usize>().map_err(|parse_error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid Content-Length header: {parse_error}"), + ) + })?; + content_length = Some(parsed); + } + } + + let length = content_length.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header") + })?; + + let mut payload = vec![0_u8; length]; + let _bytes_read = reader.read_exact(&mut payload).await?; + Ok(payload) +} + +async fn write_frame(writer: &mut ChildStdin, value: &Value) -> Result<(), io::Error> { + let payload = serde_json::to_vec(value).map_err(|error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to serialize JSON-RPC payload: {error}"), + ) + })?; + let header = format!("Content-Length: {}\r\n\r\n", payload.len()); + writer.write_all(header.as_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) +} + +fn classify_io_fault( + generation: Generation, + class: FaultClass, + context: &'static str, + error: io::Error, +) -> Fault { + let code = match error.kind() { + io::ErrorKind::BrokenPipe => FaultCode::BrokenPipe, + io::ErrorKind::UnexpectedEof => FaultCode::UnexpectedEof, + _ => match class { + FaultClass::Process => FaultCode::SpawnFailed, + _ => FaultCode::InvalidFrame, + }, + }; + Fault::new( + generation, + class, + code, + FaultDetail::new(format!("{context}: {error}")), + ) +} |