use crate::config::EngineConfig; use ra_mcp_domain::{ fault::{Fault, FaultClass, FaultCode, FaultDetail}, types::Generation, }; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::{ collections::HashMap, io, process::Stdio, sync::{ Arc, atomic::{AtomicU64, Ordering}, }, time::Duration, }; use tokio::{ io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, process::{Child, ChildStdin, ChildStdout, Command}, sync::{Mutex, oneshot, watch}, task::JoinHandle, }; use tracing::{debug, warn}; use url::Url; #[derive(Debug, Clone)] pub(crate) struct WorkerHandle { generation: Generation, child: Arc>, writer: Arc>, pending: Arc>>>, next_request_id: Arc, terminal_fault_rx: watch::Receiver>, reader_task: Arc>>>, stderr_task: Arc>>>, } #[derive(Debug)] enum PendingOutcome { Result(Value), ResponseError(RpcErrorPayload), TransportFault(Fault), } #[derive(Debug, Clone, Deserialize)] pub(crate) struct RpcErrorPayload { pub(crate) code: i64, pub(crate) message: String, pub(crate) data: Option, } #[derive(Debug)] pub(crate) enum WorkerRequestError { Fault(Fault), Response(RpcErrorPayload), } impl WorkerHandle { pub(crate) fn terminal_fault(&self) -> Option { self.terminal_fault_rx.borrow().clone() } pub(crate) async fn send_notification( &self, method: &'static str, params: &impl Serialize, ) -> Result<(), Fault> { let payload = json!({ "jsonrpc": "2.0", "method": method, "params": params, }); let mut writer = self.writer.lock().await; write_frame(&mut writer, &payload).await.map_err(|error| { classify_io_fault( self.generation, FaultClass::Transport, "failed to write notification", error, ) }) } pub(crate) async fn send_request( &self, method: &'static str, params: &impl Serialize, timeout: Duration, ) -> Result { let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); let (sender, receiver) = oneshot::channel::(); { let mut pending = self.pending.lock().await; let previous = pending.insert(request_id, sender); if let Some(previous_sender) = previous { drop(previous_sender); } } let payload = json!({ "jsonrpc": "2.0", "id": request_id, "method": method, "params": params, }); let write_result = { let mut writer = self.writer.lock().await; write_frame(&mut writer, &payload).await }; if let Err(error) = write_result { let mut pending = self.pending.lock().await; let removed = pending.remove(&request_id); if let Some(sender) = removed { drop(sender); } return Err(WorkerRequestError::Fault(classify_io_fault( self.generation, FaultClass::Transport, "failed to write request", error, ))); } match tokio::time::timeout(timeout, receiver).await { Ok(Ok(PendingOutcome::Result(value))) => Ok(value), Ok(Ok(PendingOutcome::ResponseError(error))) => { Err(WorkerRequestError::Response(error)) } Ok(Ok(PendingOutcome::TransportFault(fault))) => Err(WorkerRequestError::Fault(fault)), Ok(Err(_closed)) => Err(WorkerRequestError::Fault(Fault::new( self.generation, FaultClass::Transport, FaultCode::UnexpectedEof, FaultDetail::new("response channel closed before result"), ))), Err(_elapsed) => { let mut pending = self.pending.lock().await; let removed = pending.remove(&request_id); if let Some(sender) = removed { drop(sender); } Err(WorkerRequestError::Fault(Fault::new( self.generation, FaultClass::Timeout, FaultCode::RequestTimedOut, FaultDetail::new(format!("request timed out for method {method}")), ))) } } } pub(crate) async fn terminate(&self) { let mut child = self.child.lock().await; if child.id().is_some() && let Err(error) = child.kill().await { debug!( generation = self.generation.get(), "failed to kill rust-analyzer process cleanly: {error}" ); } if let Err(error) = child.wait().await { debug!( generation = self.generation.get(), "failed to wait rust-analyzer process cleanly: {error}" ); } if let Some(task) = self.reader_task.lock().await.take() { task.abort(); } if let Some(task) = self.stderr_task.lock().await.take() { task.abort(); } } } pub(crate) async fn spawn_worker( config: &EngineConfig, generation: Generation, ) -> Result { let mut command = Command::new(&config.rust_analyzer_binary); let _args = command.args(&config.rust_analyzer_args); let _envs = command.envs(config.rust_analyzer_env.iter().cloned()); let _configured = command .current_dir(config.workspace_root.as_path()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()); let mut child = command.spawn().map_err(|error| { classify_io_fault( generation, FaultClass::Process, "failed to spawn rust-analyzer", error, ) })?; let stdin = child.stdin.take().ok_or_else(|| { Fault::new( generation, FaultClass::Process, FaultCode::SpawnFailed, FaultDetail::new("missing stdin pipe from rust-analyzer process"), ) })?; let stdout = child.stdout.take().ok_or_else(|| { Fault::new( generation, FaultClass::Process, FaultCode::SpawnFailed, FaultDetail::new("missing stdout pipe from rust-analyzer process"), ) })?; let stderr = child.stderr.take().ok_or_else(|| { Fault::new( generation, FaultClass::Process, FaultCode::SpawnFailed, FaultDetail::new("missing stderr pipe from rust-analyzer process"), ) })?; let child = Arc::new(Mutex::new(child)); let writer = Arc::new(Mutex::new(stdin)); let pending = Arc::new(Mutex::new( HashMap::>::new(), )); let next_request_id = Arc::new(AtomicU64::new(1)); let (terminal_fault_tx, terminal_fault_rx) = watch::channel(None::); let reader_task = { let pending = Arc::clone(&pending); let terminal_fault_tx = terminal_fault_tx.clone(); tokio::spawn(async move { read_stdout_loop(generation, stdout, pending, terminal_fault_tx).await; }) }; let stderr_task = tokio::spawn(async move { stream_stderr(generation, stderr).await; }); let handle = WorkerHandle { generation, child, writer, pending, next_request_id, terminal_fault_rx, reader_task: Arc::new(Mutex::new(Some(reader_task))), stderr_task: Arc::new(Mutex::new(Some(stderr_task))), }; let initialize_params = build_initialize_params(config)?; let startup = handle .send_request("initialize", &initialize_params, config.startup_timeout) .await; if let Err(error) = startup { handle.terminate().await; return Err(map_worker_request_error(generation, error)); } let initialized_params = json!({}); let initialized_result = handle .send_notification("initialized", &initialized_params) .await .map_err(|fault| { handle_fault_notification(generation, "initialized notification failed", fault) }); if let Err(fault) = initialized_result { handle.terminate().await; return Err(fault); } Ok(handle) } fn map_worker_request_error(generation: Generation, error: WorkerRequestError) -> Fault { match error { WorkerRequestError::Fault(fault) => fault, WorkerRequestError::Response(response) => Fault::new( generation, FaultClass::Protocol, FaultCode::InvalidFrame, FaultDetail::new(format!( "initialize returned LSP error {}: {}", response.code, response.message )), ), } } fn handle_fault_notification(generation: Generation, context: &'static str, fault: Fault) -> Fault { let detail = FaultDetail::new(format!("{context}: {}", fault.detail.message)); Fault::new(generation, fault.class, fault.code, detail) } fn build_initialize_params(config: &EngineConfig) -> Result { let root_uri = Url::from_directory_path(config.workspace_root.as_path()).map_err(|()| { Fault::new( Generation::genesis(), FaultClass::Protocol, FaultCode::InvalidFrame, FaultDetail::new("workspace root cannot be represented as file URI"), ) })?; let folder_name = config .workspace_root .as_path() .file_name() .and_then(|value| value.to_str()) .unwrap_or("workspace") .to_owned(); let root_uri_string = root_uri.to_string(); Ok(json!({ "processId": std::process::id(), "rootUri": root_uri_string.clone(), "capabilities": build_client_capabilities(), "workspaceFolders": [{ "uri": root_uri_string, "name": folder_name, }], "trace": "off", "clientInfo": { "name": "adequate-rust-mcp", "version": env!("CARGO_PKG_VERSION"), } })) } fn build_client_capabilities() -> Value { let symbol_kind_values = (1_u32..=26).collect::>(); json!({ "workspace": { "applyEdit": true, "workspaceEdit": { "documentChanges": true, "resourceOperations": ["create", "rename", "delete"], }, "symbol": { "dynamicRegistration": false, "resolveSupport": { "properties": ["location.range", "containerName"], }, }, "diagnostics": { "refreshSupport": true, }, "executeCommand": { "dynamicRegistration": false, }, "workspaceFolders": true, "configuration": true, }, "textDocument": { "synchronization": { "dynamicRegistration": false, "willSave": false, "didSave": true, "willSaveWaitUntil": false, }, "hover": { "dynamicRegistration": false, "contentFormat": ["markdown", "plaintext"], }, "definition": { "dynamicRegistration": false, "linkSupport": true, }, "declaration": { "dynamicRegistration": false, "linkSupport": true, }, "typeDefinition": { "dynamicRegistration": false, "linkSupport": true, }, "implementation": { "dynamicRegistration": false, "linkSupport": true, }, "references": { "dynamicRegistration": false, }, "documentHighlight": { "dynamicRegistration": false, }, "documentSymbol": { "dynamicRegistration": false, "hierarchicalDocumentSymbolSupport": true, "symbolKind": { "valueSet": symbol_kind_values, }, }, "completion": { "dynamicRegistration": false, "contextSupport": true, "completionItem": { "snippetSupport": true, "documentationFormat": ["markdown", "plaintext"], "resolveSupport": { "properties": ["documentation", "detail", "additionalTextEdits"], }, }, }, "signatureHelp": { "dynamicRegistration": false, }, "codeAction": { "dynamicRegistration": false, "isPreferredSupport": true, "codeActionLiteralSupport": { "codeActionKind": { "valueSet": [ "", "quickfix", "refactor", "refactor.extract", "refactor.inline", "refactor.rewrite", "source", "source.organizeImports", ], }, }, }, "codeLens": { "dynamicRegistration": false, }, "documentLink": { "dynamicRegistration": false, "tooltipSupport": true, }, "colorProvider": { "dynamicRegistration": false, }, "linkedEditingRange": { "dynamicRegistration": false, }, "rename": { "dynamicRegistration": false, "prepareSupport": true, }, "typeHierarchy": { "dynamicRegistration": false, }, "inlineValue": { "dynamicRegistration": false, }, "moniker": { "dynamicRegistration": false, }, "diagnostic": { "dynamicRegistration": false, }, "documentFormatting": { "dynamicRegistration": false, }, "documentRangeFormatting": { "dynamicRegistration": false, }, "documentOnTypeFormatting": { "dynamicRegistration": false, }, "foldingRange": { "dynamicRegistration": false, }, "selectionRange": { "dynamicRegistration": false, }, "inlayHint": { "dynamicRegistration": false, "resolveSupport": { "properties": ["tooltip", "textEdits", "label.tooltip", "label.location", "label.command"], }, }, "semanticTokens": { "dynamicRegistration": false, "requests": { "full": { "delta": true, }, "range": true, }, "tokenTypes": [ "namespace", "type", "class", "enum", "interface", "struct", "typeParameter", "parameter", "variable", "property", "enumMember", "event", "function", "method", "macro", "keyword", "modifier", "comment", "string", "number", "regexp", "operator", ], "tokenModifiers": [ "declaration", "definition", "readonly", "static", "deprecated", "abstract", "async", "modification", "documentation", "defaultLibrary", ], "formats": ["relative"], "multilineTokenSupport": true, "overlappingTokenSupport": true, }, "publishDiagnostics": { "relatedInformation": true, }, }, "window": { "workDoneProgress": true, }, "general": { "positionEncodings": ["utf-8", "utf-16"], }, }) } async fn stream_stderr(generation: Generation, stderr: tokio::process::ChildStderr) { let mut reader = BufReader::new(stderr).lines(); loop { match reader.next_line().await { Ok(Some(line)) => { debug!( generation = generation.get(), "rust-analyzer stderr: {line}" ); } Ok(None) => break, Err(error) => { debug!( generation = generation.get(), "rust-analyzer stderr stream failed: {error}" ); break; } } } } async fn read_stdout_loop( generation: Generation, stdout: ChildStdout, pending: Arc>>>, terminal_fault_tx: watch::Sender>, ) { let mut reader = BufReader::new(stdout); loop { match read_frame(&mut reader).await { Ok(frame) => { if let Err(fault) = dispatch_frame(generation, &pending, &frame).await { emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; break; } } Err(error) => { let fault = classify_io_fault( generation, FaultClass::Transport, "failed to read frame", error, ); emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; break; } } } } async fn emit_terminal_fault( terminal_fault_tx: &watch::Sender>, pending: &Arc>>>, fault: Fault, ) { if let Err(error) = terminal_fault_tx.send(Some(fault.clone())) { warn!("failed to publish terminal fault: {error}"); } let mut pending_guard = pending.lock().await; for sender in pending_guard.drain().map(|(_id, sender)| sender) { if let Err(outcome) = sender.send(PendingOutcome::TransportFault(fault.clone())) { drop(outcome); } } } async fn dispatch_frame( generation: Generation, pending: &Arc>>>, frame: &[u8], ) -> Result<(), Fault> { let value: Value = serde_json::from_slice(frame).map_err(|error| { Fault::new( generation, FaultClass::Protocol, FaultCode::InvalidJson, FaultDetail::new(format!("failed to deserialize JSON-RPC frame: {error}")), ) })?; let response_id = value.get("id").and_then(Value::as_u64); let Some(response_id) = response_id else { return Ok(()); }; let mut pending_guard = pending.lock().await; let Some(sender) = pending_guard.remove(&response_id) else { warn!( generation = generation.get(), response_id, "received response for unknown request id" ); return Ok(()); }; drop(pending_guard); if let Some(result) = value.get("result") { if let Err(outcome) = sender.send(PendingOutcome::Result(result.clone())) { drop(outcome); } return Ok(()); } if let Some(error_value) = value.get("error") { let error: RpcErrorPayload = serde_json::from_value(error_value.clone()).map_err(|error| { Fault::new( generation, FaultClass::Protocol, FaultCode::InvalidJson, FaultDetail::new(format!( "failed to deserialize JSON-RPC error payload: {error}" )), ) })?; if let Err(outcome) = sender.send(PendingOutcome::ResponseError(error)) { drop(outcome); } return Ok(()); } Err(Fault::new( generation, FaultClass::Protocol, FaultCode::InvalidFrame, FaultDetail::new("response frame missing both result and error"), )) } async fn read_frame(reader: &mut BufReader) -> Result, io::Error> { let mut content_length = None::; loop { let mut header_line = String::new(); let bytes_read = reader.read_line(&mut header_line).await?; if bytes_read == 0 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "EOF while reading headers", )); } if header_line == "\r\n" || header_line == "\n" { break; } let trimmed = header_line.trim_end_matches(['\r', '\n']); if let Some(length) = trimmed.strip_prefix("Content-Length:") { let parsed = length.trim().parse::().map_err(|parse_error| { io::Error::new( io::ErrorKind::InvalidData, format!("invalid Content-Length header: {parse_error}"), ) })?; content_length = Some(parsed); } } let length = content_length.ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header") })?; let mut payload = vec![0_u8; length]; let _bytes_read = reader.read_exact(&mut payload).await?; Ok(payload) } async fn write_frame(writer: &mut ChildStdin, value: &Value) -> Result<(), io::Error> { let payload = serde_json::to_vec(value).map_err(|error| { io::Error::new( io::ErrorKind::InvalidData, format!("failed to serialize JSON-RPC payload: {error}"), ) })?; let header = format!("Content-Length: {}\r\n\r\n", payload.len()); writer.write_all(header.as_bytes()).await?; writer.write_all(&payload).await?; writer.flush().await?; Ok(()) } fn classify_io_fault( generation: Generation, class: FaultClass, context: &'static str, error: io::Error, ) -> Fault { let code = match error.kind() { io::ErrorKind::BrokenPipe => FaultCode::BrokenPipe, io::ErrorKind::UnexpectedEof => FaultCode::UnexpectedEof, _ => match class { FaultClass::Process => FaultCode::SpawnFailed, _ => FaultCode::InvalidFrame, }, }; Fault::new( generation, class, code, FaultDetail::new(format!("{context}: {error}")), ) }