diff options
| author | main <main@swarm.moe> | 2026-03-19 15:49:41 -0400 |
|---|---|---|
| committer | main <main@swarm.moe> | 2026-03-19 15:49:41 -0400 |
| commit | fa1bd32800b65aab31ea732dd240261b4047522c (patch) | |
| tree | 2fd08af6f36b8beb3c7c941990becc1a0a091d62 /crates/ra-mcp-engine | |
| download | adequate-rust-mcp-fa1bd32800b65aab31ea732dd240261b4047522c.zip | |
Release adequate-rust-mcp 1.0.0v1.0.0
Diffstat (limited to 'crates/ra-mcp-engine')
| -rw-r--r-- | crates/ra-mcp-engine/.gitignore | 1 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/Cargo.toml | 28 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/bin/fake-rust-analyzer.rs | 467 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/config.rs | 79 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/error.rs | 77 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/lib.rs | 20 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/lsp_transport.rs | 717 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/src/supervisor.rs | 1257 | ||||
| -rw-r--r-- | crates/ra-mcp-engine/tests/engine_recovery.rs | 353 |
9 files changed, 2999 insertions, 0 deletions
diff --git a/crates/ra-mcp-engine/.gitignore b/crates/ra-mcp-engine/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/crates/ra-mcp-engine/.gitignore @@ -0,0 +1 @@ +/target diff --git a/crates/ra-mcp-engine/Cargo.toml b/crates/ra-mcp-engine/Cargo.toml new file mode 100644 index 0000000..d5d870d --- /dev/null +++ b/crates/ra-mcp-engine/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ra-mcp-engine" +categories.workspace = true +description = "Resilient rust-analyzer transport and worker-supervision engine used by adequate-rust-mcp." +edition.workspace = true +keywords.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +lsp-types.workspace = true +ra-mcp-domain = { path = "../ra-mcp-domain" } +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +tokio.workspace = true +tracing.workspace = true +url.workspace = true + +[dev-dependencies] +serial_test.workspace = true +tempfile.workspace = true + +[lints] +workspace = true diff --git a/crates/ra-mcp-engine/src/bin/fake-rust-analyzer.rs b/crates/ra-mcp-engine/src/bin/fake-rust-analyzer.rs new file mode 100644 index 0000000..c64b68b --- /dev/null +++ b/crates/ra-mcp-engine/src/bin/fake-rust-analyzer.rs @@ -0,0 +1,467 @@ +//! Fault-injectable fake rust-analyzer used by integration tests. 
+ +use lsp_types as _; +use ra_mcp_domain as _; +use ra_mcp_engine as _; +use serde as _; +use serde_json::{Value, json}; +#[cfg(test)] +use serial_test as _; +use std::{ + fs, + io::{self, BufRead, BufReader, Read, Write}, + path::{Path, PathBuf}, + time::Duration, +}; +#[cfg(test)] +use tempfile as _; +use thiserror as _; +use tokio as _; +use tracing as _; +use url as _; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Mode { + Stable, + CrashOnFirstHover, +} + +fn main() -> Result<(), Box<dyn std::error::Error>> { + run().map_err(|error| Box::new(error) as Box<dyn std::error::Error>) +} + +fn run() -> io::Result<()> { + let mut mode = Mode::Stable; + let mut marker = None::<PathBuf>; + let mut hover_delay = Duration::ZERO; + let mut execute_command_delay = Duration::ZERO; + let mut execute_command_log = None::<PathBuf>; + let mut diagnostic_warmup_count = 0_u8; + let mut diagnostic_cancel_count = 0_u8; + let mut strict_root_match = false; + let mut workspace_root = None::<PathBuf>; + let mut args = std::env::args().skip(1); + loop { + let argument = args.next(); + let Some(argument) = argument else { + break; + }; + match argument.as_str() { + "--mode" => { + if let Some(value) = args.next() { + mode = parse_mode(&value).unwrap_or(Mode::Stable); + } + } + "--crash-marker" => { + if let Some(value) = args.next() { + marker = Some(PathBuf::from(value)); + } + } + "--hover-delay-ms" => { + if let Some(value) = args.next() { + let parsed = value.parse::<u64>().ok(); + if let Some(delay_ms) = parsed { + hover_delay = Duration::from_millis(delay_ms); + } + } + } + "--execute-command-delay-ms" => { + if let Some(value) = args.next() { + let parsed = value.parse::<u64>().ok(); + if let Some(delay_ms) = parsed { + execute_command_delay = Duration::from_millis(delay_ms); + } + } + } + "--execute-command-log" => { + if let Some(value) = args.next() { + execute_command_log = Some(PathBuf::from(value)); + } + } + "--diagnostic-warmup-count" => { + if let Some(value) = 
args.next() { + let parsed = value.parse::<u8>().ok(); + if let Some(count) = parsed { + diagnostic_warmup_count = count; + } + } + } + "--diagnostic-cancel-count" => { + if let Some(value) = args.next() { + let parsed = value.parse::<u8>().ok(); + if let Some(count) = parsed { + diagnostic_cancel_count = count; + } + } + } + "--strict-root-match" => { + strict_root_match = true; + } + _ => {} + } + } + + let stdin = io::stdin(); + let stdout = io::stdout(); + let mut reader = BufReader::new(stdin.lock()); + let mut writer = stdout.lock(); + + loop { + let frame = match read_frame(&mut reader) { + Ok(frame) => frame, + Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => break, + Err(error) => return Err(error), + }; + let message: Value = serde_json::from_slice(&frame) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + if let Some(method) = message.get("method").and_then(Value::as_str) { + if method == "initialized" { + continue; + } + + let request_id = message.get("id").cloned(); + let Some(request_id) = request_id else { + continue; + }; + if method == "initialize" { + workspace_root = initialized_workspace_root(&message); + } + + if mode == Mode::CrashOnFirstHover + && method == "textDocument/hover" + && should_crash(&marker)? 
+ { + std::process::exit(0); + } + if method == "textDocument/hover" && !hover_delay.is_zero() { + std::thread::sleep(hover_delay); + } + if method == "workspace/executeCommand" { + if let Some(path) = execute_command_log.as_ref() { + log_execute_command_effect(path, &message)?; + } + if !execute_command_delay.is_zero() { + std::thread::sleep(execute_command_delay); + } + } + + let response = if strict_root_match + && request_targets_outside_workspace(&message, workspace_root.as_deref()) + { + strict_root_mismatch_response(method, request_id, &message) + } else if method == "textDocument/diagnostic" && diagnostic_cancel_count > 0 { + diagnostic_cancel_count = diagnostic_cancel_count.saturating_sub(1); + server_cancelled_response(request_id) + } else if method == "textDocument/diagnostic" && diagnostic_warmup_count > 0 { + diagnostic_warmup_count = diagnostic_warmup_count.saturating_sub(1); + warmup_unlinked_diagnostic_response(request_id) + } else { + make_response(method, request_id, &message) + }; + write_frame(&mut writer, &response)?; + } + } + + Ok(()) +} + +fn parse_mode(raw: &str) -> Option<Mode> { + match raw { + "stable" => Some(Mode::Stable), + "crash_on_first_hover" => Some(Mode::CrashOnFirstHover), + _ => None, + } +} + +fn should_crash(marker: &Option<PathBuf>) -> io::Result<bool> { + let Some(marker) = marker else { + return Ok(true); + }; + if marker.exists() { + return Ok(false); + } + fs::write(marker, b"crashed")?; + Ok(true) +} + +fn log_execute_command_effect(path: &PathBuf, request: &Value) -> io::Result<()> { + let command = request + .get("params") + .and_then(|params| params.get("command")) + .and_then(Value::as_str) + .unwrap_or("<missing-command>"); + let mut file = fs::OpenOptions::new() + .create(true) + .append(true) + .open(path)?; + writeln!(file, "{command}")?; + Ok(()) +} + +fn initialized_workspace_root(request: &Value) -> Option<PathBuf> { + let root_uri = request + .get("params") + .and_then(|params| params.get("rootUri")) + 
.and_then(Value::as_str)?; + let root_url = url::Url::parse(root_uri).ok()?; + root_url.to_file_path().ok() +} + +fn request_targets_outside_workspace(request: &Value, workspace_root: Option<&Path>) -> bool { + let Some(workspace_root) = workspace_root else { + return false; + }; + let file_path = request_document_path(request); + let Some(file_path) = file_path else { + return false; + }; + !file_path.starts_with(workspace_root) +} + +fn request_document_path(request: &Value) -> Option<PathBuf> { + let uri = request + .get("params") + .and_then(|params| params.get("textDocument")) + .and_then(|doc| doc.get("uri")) + .and_then(Value::as_str)?; + let url = url::Url::parse(uri).ok()?; + url.to_file_path().ok() +} + +fn strict_root_mismatch_response(method: &str, request_id: Value, request: &Value) -> Value { + match method { + "textDocument/hover" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": Value::Null + }), + "textDocument/definition" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": Value::Null + }), + "textDocument/references" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": Value::Null + }), + "textDocument/rename" => { + let uri = request + .get("params") + .and_then(|params| params.get("textDocument")) + .and_then(|doc| doc.get("uri")) + .and_then(Value::as_str) + .unwrap_or("file:///tmp/fallback.rs") + .to_owned(); + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "changes": { + uri: [] + } + } + }) + } + "textDocument/diagnostic" => warmup_unlinked_diagnostic_response(request_id), + _ => make_response(method, request_id, request), + } +} + +fn make_response(method: &str, request_id: Value, request: &Value) -> Value { + match method { + "initialize" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "capabilities": {} + } + }), + "textDocument/hover" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "contents": { + "kind": "markdown", + "value": "hover::ok" + } + } + 
}), + "textDocument/definition" => { + let uri = request + .get("params") + .and_then(|params| params.get("textDocument")) + .and_then(|doc| doc.get("uri")) + .cloned() + .unwrap_or(Value::String("file:///tmp/fallback.rs".to_owned())); + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": [{ + "uri": uri, + "range": { + "start": { "line": 2, "character": 3 }, + "end": { "line": 2, "character": 8 } + } + }] + }) + } + "textDocument/references" => { + let uri = request + .get("params") + .and_then(|params| params.get("textDocument")) + .and_then(|doc| doc.get("uri")) + .cloned() + .unwrap_or(Value::String("file:///tmp/fallback.rs".to_owned())); + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": [{ + "uri": uri, + "range": { + "start": { "line": 4, "character": 1 }, + "end": { "line": 4, "character": 5 } + } + }] + }) + } + "textDocument/rename" => { + let uri = request + .get("params") + .and_then(|params| params.get("textDocument")) + .and_then(|doc| doc.get("uri")) + .and_then(Value::as_str) + .unwrap_or("file:///tmp/fallback.rs") + .to_owned(); + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "changes": { + uri: [ + { + "range": { + "start": { "line": 1, "character": 1 }, + "end": { "line": 1, "character": 4 } + }, + "newText": "renamed_symbol" + } + ] + } + } + }) + } + "textDocument/diagnostic" => json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "kind": "full", + "items": [{ + "range": { + "start": { "line": 0, "character": 0 }, + "end": { "line": 0, "character": 3 } + }, + "severity": 1, + "message": "fake diagnostic" + }] + } + }), + "workspace/executeCommand" => { + let command = request + .get("params") + .and_then(|params| params.get("command")) + .cloned() + .unwrap_or(Value::Null); + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "ack": "ok", + "command": command + } + }) + } + _ => json!({ + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": format!("method 
not found: {method}") + } + }), + } +} + +fn warmup_unlinked_diagnostic_response(request_id: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "kind": "full", + "items": [{ + "range": { + "start": { "line": 0, "character": 0 }, + "end": { "line": 0, "character": 0 } + }, + "severity": 2, + "code": "unlinked-file", + "message": "This file is not part of any crate, so rust-analyzer can't offer IDE services." + }] + } + }) +} + +fn server_cancelled_response(request_id: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32802, + "message": "server cancelled request during workspace reload" + } + }) +} + +fn read_frame(reader: &mut BufReader<impl Read>) -> io::Result<Vec<u8>> { + let mut content_length = None::<usize>; + loop { + let mut line = String::new(); + let bytes = reader.read_line(&mut line)?; + if bytes == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF while reading headers", + )); + } + if line == "\r\n" || line == "\n" { + break; + } + let trimmed = line.trim_end_matches(['\r', '\n']); + if let Some(raw_length) = trimmed.strip_prefix("Content-Length:") { + let parsed = raw_length.trim().parse::<usize>().map_err(|error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid Content-Length header: {error}"), + ) + })?; + content_length = Some(parsed); + } + } + + let length = content_length.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header") + })?; + let mut payload = vec![0_u8; length]; + reader.read_exact(&mut payload)?; + Ok(payload) +} + +fn write_frame(writer: &mut impl Write, payload: &Value) -> io::Result<()> { + let serialized = serde_json::to_vec(payload) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + let header = format!("Content-Length: {}\r\n\r\n", serialized.len()); + writer.write_all(header.as_bytes())?; + writer.write_all(&serialized)?; + 
writer.flush()?; + Ok(()) +} diff --git a/crates/ra-mcp-engine/src/config.rs b/crates/ra-mcp-engine/src/config.rs new file mode 100644 index 0000000..8d116d5 --- /dev/null +++ b/crates/ra-mcp-engine/src/config.rs @@ -0,0 +1,79 @@ +use ra_mcp_domain::types::{InvariantViolation, WorkspaceRoot}; +use std::{path::PathBuf, time::Duration}; + +/// Exponential backoff policy for worker restart attempts. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct BackoffPolicy { + /// Minimum delay between restart attempts. + pub floor: Duration, + /// Maximum delay between restart attempts. + pub ceiling: Duration, +} + +impl BackoffPolicy { + /// Builds a validated backoff policy. + pub fn try_new(floor: Duration, ceiling: Duration) -> Result<Self, InvariantViolation> { + if floor.is_zero() { + return Err(InvariantViolation::new("backoff floor must be non-zero")); + } + if ceiling < floor { + return Err(InvariantViolation::new( + "backoff ceiling must be greater than or equal to floor", + )); + } + Ok(Self { floor, ceiling }) + } +} + +/// Runtime engine configuration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EngineConfig { + /// Absolute workspace root used for rust-analyzer process cwd/root URI. + pub workspace_root: WorkspaceRoot, + /// rust-analyzer executable path. + pub rust_analyzer_binary: PathBuf, + /// Additional rust-analyzer process arguments. + pub rust_analyzer_args: Vec<String>, + /// Additional rust-analyzer process environment variables. + pub rust_analyzer_env: Vec<(String, String)>, + /// Startup handshake timeout. + pub startup_timeout: Duration, + /// Timeout for ordinary requests. + pub request_timeout: Duration, + /// Restart backoff policy. + pub backoff_policy: BackoffPolicy, +} + +impl EngineConfig { + /// Builds validated engine configuration. 
+ pub fn try_new( + workspace_root: WorkspaceRoot, + rust_analyzer_binary: PathBuf, + rust_analyzer_args: Vec<String>, + rust_analyzer_env: Vec<(String, String)>, + startup_timeout: Duration, + request_timeout: Duration, + backoff_policy: BackoffPolicy, + ) -> Result<Self, InvariantViolation> { + if rust_analyzer_binary.as_os_str().is_empty() { + return Err(InvariantViolation::new( + "rust-analyzer binary path must be non-empty", + )); + } + if startup_timeout.is_zero() { + return Err(InvariantViolation::new("startup timeout must be non-zero")); + } + if request_timeout.is_zero() { + return Err(InvariantViolation::new("request timeout must be non-zero")); + } + Ok(Self { + workspace_root, + rust_analyzer_binary, + rust_analyzer_args, + rust_analyzer_env, + startup_timeout, + request_timeout, + backoff_policy, + }) + } +} diff --git a/crates/ra-mcp-engine/src/error.rs b/crates/ra-mcp-engine/src/error.rs new file mode 100644 index 0000000..f40e1ae --- /dev/null +++ b/crates/ra-mcp-engine/src/error.rs @@ -0,0 +1,77 @@ +use crate::lsp_transport::RpcErrorPayload; +use ra_mcp_domain::{fault::Fault, types::InvariantViolation}; +use serde_json::Value; +use thiserror::Error; + +/// Engine result type. +pub type EngineResult<T> = Result<T, EngineError>; + +/// Structured rust-analyzer response error. +#[derive(Debug, Clone, Error)] +#[error("lsp response error: code={code}, message={message}")] +pub struct LspResponseError { + /// LSP JSON-RPC error code. + pub code: i64, + /// LSP JSON-RPC error message. + pub message: String, + /// Optional JSON-RPC error data payload. + pub data: Option<Value>, +} + +/// Engine failure type. +#[derive(Debug, Error)] +pub enum EngineError { + /// I/O failure while syncing source documents. + #[error("io error: {0}")] + Io(#[from] std::io::Error), + /// Domain invariant was violated. + #[error(transparent)] + Invariant(#[from] InvariantViolation), + /// Transport/process/protocol fault. 
+ #[error("engine fault: {0:?}")] + Fault(Fault), + /// rust-analyzer returned a JSON-RPC error object. + #[error(transparent)] + LspResponse(#[from] LspResponseError), + /// Response payload could not be deserialized into expected type. + #[error("invalid lsp payload for method {method}: {message}")] + InvalidPayload { + /// Request method. + method: &'static str, + /// Decoder error detail. + message: String, + }, + /// Request params could not be serialized into JSON. + #[error("invalid lsp request payload for method {method}: {message}")] + InvalidRequest { + /// Request method. + method: &'static str, + /// Encoder error detail. + message: String, + }, + /// Path to URL conversion failed. + #[error("path cannot be represented as file URL")] + InvalidFileUrl, +} + +impl From<Fault> for EngineError { + fn from(value: Fault) -> Self { + Self::Fault(value) + } +} + +impl From<RpcErrorPayload> for LspResponseError { + fn from(value: RpcErrorPayload) -> Self { + Self { + code: value.code, + message: value.message, + data: value.data, + } + } +} + +impl From<RpcErrorPayload> for EngineError { + fn from(value: RpcErrorPayload) -> Self { + Self::LspResponse(value.into()) + } +} diff --git a/crates/ra-mcp-engine/src/lib.rs b/crates/ra-mcp-engine/src/lib.rs new file mode 100644 index 0000000..3d36a5b --- /dev/null +++ b/crates/ra-mcp-engine/src/lib.rs @@ -0,0 +1,20 @@ +#![recursion_limit = "512"] + +//! Resilient rust-analyzer execution engine and typed LSP façade. 
+ +#[cfg(test)] +use serial_test as _; +#[cfg(test)] +use tempfile as _; + +mod config; +mod error; +mod lsp_transport; +mod supervisor; + +pub use config::{BackoffPolicy, EngineConfig}; +pub use error::{EngineError, EngineResult}; +pub use supervisor::{ + DiagnosticEntry, DiagnosticLevel, DiagnosticsReport, Engine, HoverPayload, + MethodTelemetrySnapshot, RenameReport, TelemetrySnapshot, TelemetryTotals, +}; diff --git a/crates/ra-mcp-engine/src/lsp_transport.rs b/crates/ra-mcp-engine/src/lsp_transport.rs new file mode 100644 index 0000000..c47d4f2 --- /dev/null +++ b/crates/ra-mcp-engine/src/lsp_transport.rs @@ -0,0 +1,717 @@ +use crate::config::EngineConfig; +use ra_mcp_domain::{ + fault::{Fault, FaultClass, FaultCode, FaultDetail}, + types::Generation, +}; +use serde::{Deserialize, Serialize}; +use serde_json::{Value, json}; +use std::{ + collections::HashMap, + io, + process::Stdio, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, + time::Duration, +}; +use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}, + process::{Child, ChildStdin, ChildStdout, Command}, + sync::{Mutex, oneshot, watch}, + task::JoinHandle, +}; +use tracing::{debug, warn}; +use url::Url; + +#[derive(Debug, Clone)] +pub(crate) struct WorkerHandle { + generation: Generation, + child: Arc<Mutex<Child>>, + writer: Arc<Mutex<ChildStdin>>, + pending: Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + next_request_id: Arc<AtomicU64>, + terminal_fault_rx: watch::Receiver<Option<Fault>>, + reader_task: Arc<Mutex<Option<JoinHandle<()>>>>, + stderr_task: Arc<Mutex<Option<JoinHandle<()>>>>, +} + +#[derive(Debug)] +enum PendingOutcome { + Result(Value), + ResponseError(RpcErrorPayload), + TransportFault(Fault), +} + +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct RpcErrorPayload { + pub(crate) code: i64, + pub(crate) message: String, + pub(crate) data: Option<Value>, +} + +#[derive(Debug)] +pub(crate) enum WorkerRequestError { + Fault(Fault), + 
Response(RpcErrorPayload), +} + +impl WorkerHandle { + pub(crate) fn terminal_fault(&self) -> Option<Fault> { + self.terminal_fault_rx.borrow().clone() + } + + pub(crate) async fn send_notification( + &self, + method: &'static str, + params: &impl Serialize, + ) -> Result<(), Fault> { + let payload = json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + let mut writer = self.writer.lock().await; + write_frame(&mut writer, &payload).await.map_err(|error| { + classify_io_fault( + self.generation, + FaultClass::Transport, + "failed to write notification", + error, + ) + }) + } + + pub(crate) async fn send_request( + &self, + method: &'static str, + params: &impl Serialize, + timeout: Duration, + ) -> Result<Value, WorkerRequestError> { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let (sender, receiver) = oneshot::channel::<PendingOutcome>(); + { + let mut pending = self.pending.lock().await; + let previous = pending.insert(request_id, sender); + if let Some(previous_sender) = previous { + drop(previous_sender); + } + } + + let payload = json!({ + "jsonrpc": "2.0", + "id": request_id, + "method": method, + "params": params, + }); + + let write_result = { + let mut writer = self.writer.lock().await; + write_frame(&mut writer, &payload).await + }; + + if let Err(error) = write_result { + let mut pending = self.pending.lock().await; + let removed = pending.remove(&request_id); + if let Some(sender) = removed { + drop(sender); + } + return Err(WorkerRequestError::Fault(classify_io_fault( + self.generation, + FaultClass::Transport, + "failed to write request", + error, + ))); + } + + match tokio::time::timeout(timeout, receiver).await { + Ok(Ok(PendingOutcome::Result(value))) => Ok(value), + Ok(Ok(PendingOutcome::ResponseError(error))) => { + Err(WorkerRequestError::Response(error)) + } + Ok(Ok(PendingOutcome::TransportFault(fault))) => Err(WorkerRequestError::Fault(fault)), + Ok(Err(_closed)) => 
Err(WorkerRequestError::Fault(Fault::new( + self.generation, + FaultClass::Transport, + FaultCode::UnexpectedEof, + FaultDetail::new("response channel closed before result"), + ))), + Err(_elapsed) => { + let mut pending = self.pending.lock().await; + let removed = pending.remove(&request_id); + if let Some(sender) = removed { + drop(sender); + } + Err(WorkerRequestError::Fault(Fault::new( + self.generation, + FaultClass::Timeout, + FaultCode::RequestTimedOut, + FaultDetail::new(format!("request timed out for method {method}")), + ))) + } + } + } + + pub(crate) async fn terminate(&self) { + let mut child = self.child.lock().await; + if child.id().is_some() + && let Err(error) = child.kill().await + { + debug!( + generation = self.generation.get(), + "failed to kill rust-analyzer process cleanly: {error}" + ); + } + if let Err(error) = child.wait().await { + debug!( + generation = self.generation.get(), + "failed to wait rust-analyzer process cleanly: {error}" + ); + } + + if let Some(task) = self.reader_task.lock().await.take() { + task.abort(); + } + if let Some(task) = self.stderr_task.lock().await.take() { + task.abort(); + } + } +} + +pub(crate) async fn spawn_worker( + config: &EngineConfig, + generation: Generation, +) -> Result<WorkerHandle, Fault> { + let mut command = Command::new(&config.rust_analyzer_binary); + let _args = command.args(&config.rust_analyzer_args); + let _envs = command.envs(config.rust_analyzer_env.iter().cloned()); + let _configured = command + .current_dir(config.workspace_root.as_path()) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = command.spawn().map_err(|error| { + classify_io_fault( + generation, + FaultClass::Process, + "failed to spawn rust-analyzer", + error, + ) + })?; + + let stdin = child.stdin.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stdin pipe from rust-analyzer process"), + ) + })?; + 
let stdout = child.stdout.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stdout pipe from rust-analyzer process"), + ) + })?; + let stderr = child.stderr.take().ok_or_else(|| { + Fault::new( + generation, + FaultClass::Process, + FaultCode::SpawnFailed, + FaultDetail::new("missing stderr pipe from rust-analyzer process"), + ) + })?; + + let child = Arc::new(Mutex::new(child)); + let writer = Arc::new(Mutex::new(stdin)); + let pending = Arc::new(Mutex::new( + HashMap::<u64, oneshot::Sender<PendingOutcome>>::new(), + )); + let next_request_id = Arc::new(AtomicU64::new(1)); + let (terminal_fault_tx, terminal_fault_rx) = watch::channel(None::<Fault>); + + let reader_task = { + let pending = Arc::clone(&pending); + let terminal_fault_tx = terminal_fault_tx.clone(); + tokio::spawn(async move { + read_stdout_loop(generation, stdout, pending, terminal_fault_tx).await; + }) + }; + + let stderr_task = tokio::spawn(async move { + stream_stderr(generation, stderr).await; + }); + + let handle = WorkerHandle { + generation, + child, + writer, + pending, + next_request_id, + terminal_fault_rx, + reader_task: Arc::new(Mutex::new(Some(reader_task))), + stderr_task: Arc::new(Mutex::new(Some(stderr_task))), + }; + + let initialize_params = build_initialize_params(config)?; + let startup = handle + .send_request("initialize", &initialize_params, config.startup_timeout) + .await; + if let Err(error) = startup { + handle.terminate().await; + return Err(map_worker_request_error(generation, error)); + } + + let initialized_params = json!({}); + let initialized_result = handle + .send_notification("initialized", &initialized_params) + .await + .map_err(|fault| { + handle_fault_notification(generation, "initialized notification failed", fault) + }); + if let Err(fault) = initialized_result { + handle.terminate().await; + return Err(fault); + } + + Ok(handle) +} + +fn map_worker_request_error(generation: 
Generation, error: WorkerRequestError) -> Fault { + match error { + WorkerRequestError::Fault(fault) => fault, + WorkerRequestError::Response(response) => Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new(format!( + "initialize returned LSP error {}: {}", + response.code, response.message + )), + ), + } +} + +fn handle_fault_notification(generation: Generation, context: &'static str, fault: Fault) -> Fault { + let detail = FaultDetail::new(format!("{context}: {}", fault.detail.message)); + Fault::new(generation, fault.class, fault.code, detail) +} + +fn build_initialize_params(config: &EngineConfig) -> Result<Value, Fault> { + let root_uri = Url::from_directory_path(config.workspace_root.as_path()).map_err(|()| { + Fault::new( + Generation::genesis(), + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new("workspace root cannot be represented as file URI"), + ) + })?; + let folder_name = config + .workspace_root + .as_path() + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or("workspace") + .to_owned(); + let root_uri_string = root_uri.to_string(); + Ok(json!({ + "processId": std::process::id(), + "rootUri": root_uri_string.clone(), + "capabilities": build_client_capabilities(), + "workspaceFolders": [{ + "uri": root_uri_string, + "name": folder_name, + }], + "trace": "off", + "clientInfo": { + "name": "adequate-rust-mcp", + "version": env!("CARGO_PKG_VERSION"), + } + })) +} + +fn build_client_capabilities() -> Value { + let symbol_kind_values = (1_u32..=26).collect::<Vec<_>>(); + json!({ + "workspace": { + "applyEdit": true, + "workspaceEdit": { + "documentChanges": true, + "resourceOperations": ["create", "rename", "delete"], + }, + "symbol": { + "dynamicRegistration": false, + "resolveSupport": { + "properties": ["location.range", "containerName"], + }, + }, + "diagnostics": { + "refreshSupport": true, + }, + "executeCommand": { + "dynamicRegistration": false, + }, + "workspaceFolders": 
true, + "configuration": true, + }, + "textDocument": { + "synchronization": { + "dynamicRegistration": false, + "willSave": false, + "didSave": true, + "willSaveWaitUntil": false, + }, + "hover": { + "dynamicRegistration": false, + "contentFormat": ["markdown", "plaintext"], + }, + "definition": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "declaration": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "typeDefinition": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "implementation": { + "dynamicRegistration": false, + "linkSupport": true, + }, + "references": { + "dynamicRegistration": false, + }, + "documentHighlight": { + "dynamicRegistration": false, + }, + "documentSymbol": { + "dynamicRegistration": false, + "hierarchicalDocumentSymbolSupport": true, + "symbolKind": { + "valueSet": symbol_kind_values, + }, + }, + "completion": { + "dynamicRegistration": false, + "contextSupport": true, + "completionItem": { + "snippetSupport": true, + "documentationFormat": ["markdown", "plaintext"], + "resolveSupport": { + "properties": ["documentation", "detail", "additionalTextEdits"], + }, + }, + }, + "signatureHelp": { + "dynamicRegistration": false, + }, + "codeAction": { + "dynamicRegistration": false, + "isPreferredSupport": true, + "codeActionLiteralSupport": { + "codeActionKind": { + "valueSet": [ + "", + "quickfix", + "refactor", + "refactor.extract", + "refactor.inline", + "refactor.rewrite", + "source", + "source.organizeImports", + ], + }, + }, + }, + "codeLens": { + "dynamicRegistration": false, + }, + "documentLink": { + "dynamicRegistration": false, + "tooltipSupport": true, + }, + "colorProvider": { + "dynamicRegistration": false, + }, + "linkedEditingRange": { + "dynamicRegistration": false, + }, + "rename": { + "dynamicRegistration": false, + "prepareSupport": true, + }, + "typeHierarchy": { + "dynamicRegistration": false, + }, + "inlineValue": { + "dynamicRegistration": false, + }, + "moniker": { + 
"dynamicRegistration": false, + }, + "diagnostic": { + "dynamicRegistration": false, + }, + "documentFormatting": { + "dynamicRegistration": false, + }, + "documentRangeFormatting": { + "dynamicRegistration": false, + }, + "documentOnTypeFormatting": { + "dynamicRegistration": false, + }, + "foldingRange": { + "dynamicRegistration": false, + }, + "selectionRange": { + "dynamicRegistration": false, + }, + "inlayHint": { + "dynamicRegistration": false, + "resolveSupport": { + "properties": ["tooltip", "textEdits", "label.tooltip", "label.location", "label.command"], + }, + }, + "semanticTokens": { + "dynamicRegistration": false, + "requests": { + "full": { + "delta": true, + }, + "range": true, + }, + "tokenTypes": [ + "namespace", "type", "class", "enum", "interface", "struct", "typeParameter", + "parameter", "variable", "property", "enumMember", "event", "function", + "method", "macro", "keyword", "modifier", "comment", "string", "number", + "regexp", "operator", + ], + "tokenModifiers": [ + "declaration", "definition", "readonly", "static", "deprecated", "abstract", + "async", "modification", "documentation", "defaultLibrary", + ], + "formats": ["relative"], + "multilineTokenSupport": true, + "overlappingTokenSupport": true, + }, + "publishDiagnostics": { + "relatedInformation": true, + }, + }, + "window": { + "workDoneProgress": true, + }, + "general": { + "positionEncodings": ["utf-8", "utf-16"], + }, + }) +} + +async fn stream_stderr(generation: Generation, stderr: tokio::process::ChildStderr) { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + debug!( + generation = generation.get(), + "rust-analyzer stderr: {line}" + ); + } + Ok(None) => break, + Err(error) => { + debug!( + generation = generation.get(), + "rust-analyzer stderr stream failed: {error}" + ); + break; + } + } + } +} + +async fn read_stdout_loop( + generation: Generation, + stdout: ChildStdout, + pending: Arc<Mutex<HashMap<u64, 
oneshot::Sender<PendingOutcome>>>>, + terminal_fault_tx: watch::Sender<Option<Fault>>, +) { + let mut reader = BufReader::new(stdout); + loop { + match read_frame(&mut reader).await { + Ok(frame) => { + if let Err(fault) = dispatch_frame(generation, &pending, &frame).await { + emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; + break; + } + } + Err(error) => { + let fault = classify_io_fault( + generation, + FaultClass::Transport, + "failed to read frame", + error, + ); + emit_terminal_fault(&terminal_fault_tx, &pending, fault).await; + break; + } + } + } +} + +async fn emit_terminal_fault( + terminal_fault_tx: &watch::Sender<Option<Fault>>, + pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + fault: Fault, +) { + if let Err(error) = terminal_fault_tx.send(Some(fault.clone())) { + warn!("failed to publish terminal fault: {error}"); + } + let mut pending_guard = pending.lock().await; + for sender in pending_guard.drain().map(|(_id, sender)| sender) { + if let Err(outcome) = sender.send(PendingOutcome::TransportFault(fault.clone())) { + drop(outcome); + } + } +} + +async fn dispatch_frame( + generation: Generation, + pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<PendingOutcome>>>>, + frame: &[u8], +) -> Result<(), Fault> { + let value: Value = serde_json::from_slice(frame).map_err(|error| { + Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidJson, + FaultDetail::new(format!("failed to deserialize JSON-RPC frame: {error}")), + ) + })?; + + let response_id = value.get("id").and_then(Value::as_u64); + let Some(response_id) = response_id else { + return Ok(()); + }; + + let mut pending_guard = pending.lock().await; + let Some(sender) = pending_guard.remove(&response_id) else { + warn!( + generation = generation.get(), + response_id, "received response for unknown request id" + ); + return Ok(()); + }; + drop(pending_guard); + + if let Some(result) = value.get("result") { + if let Err(outcome) = 
sender.send(PendingOutcome::Result(result.clone())) { + drop(outcome); + } + return Ok(()); + } + + if let Some(error_value) = value.get("error") { + let error: RpcErrorPayload = + serde_json::from_value(error_value.clone()).map_err(|error| { + Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidJson, + FaultDetail::new(format!( + "failed to deserialize JSON-RPC error payload: {error}" + )), + ) + })?; + if let Err(outcome) = sender.send(PendingOutcome::ResponseError(error)) { + drop(outcome); + } + return Ok(()); + } + + Err(Fault::new( + generation, + FaultClass::Protocol, + FaultCode::InvalidFrame, + FaultDetail::new("response frame missing both result and error"), + )) +} + +async fn read_frame(reader: &mut BufReader<ChildStdout>) -> Result<Vec<u8>, io::Error> { + let mut content_length = None::<usize>; + loop { + let mut header_line = String::new(); + let bytes_read = reader.read_line(&mut header_line).await?; + if bytes_read == 0 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "EOF while reading headers", + )); + } + + if header_line == "\r\n" || header_line == "\n" { + break; + } + + let trimmed = header_line.trim_end_matches(['\r', '\n']); + if let Some(length) = trimmed.strip_prefix("Content-Length:") { + let parsed = length.trim().parse::<usize>().map_err(|parse_error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid Content-Length header: {parse_error}"), + ) + })?; + content_length = Some(parsed); + } + } + + let length = content_length.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header") + })?; + + let mut payload = vec![0_u8; length]; + let _bytes_read = reader.read_exact(&mut payload).await?; + Ok(payload) +} + +async fn write_frame(writer: &mut ChildStdin, value: &Value) -> Result<(), io::Error> { + let payload = serde_json::to_vec(value).map_err(|error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to serialize JSON-RPC payload: 
{error}"), + ) + })?; + let header = format!("Content-Length: {}\r\n\r\n", payload.len()); + writer.write_all(header.as_bytes()).await?; + writer.write_all(&payload).await?; + writer.flush().await?; + Ok(()) +} + +fn classify_io_fault( + generation: Generation, + class: FaultClass, + context: &'static str, + error: io::Error, +) -> Fault { + let code = match error.kind() { + io::ErrorKind::BrokenPipe => FaultCode::BrokenPipe, + io::ErrorKind::UnexpectedEof => FaultCode::UnexpectedEof, + _ => match class { + FaultClass::Process => FaultCode::SpawnFailed, + _ => FaultCode::InvalidFrame, + }, + }; + Fault::new( + generation, + class, + code, + FaultDetail::new(format!("{context}: {error}")), + ) +} diff --git a/crates/ra-mcp-engine/src/supervisor.rs b/crates/ra-mcp-engine/src/supervisor.rs new file mode 100644 index 0000000..f0c7ea6 --- /dev/null +++ b/crates/ra-mcp-engine/src/supervisor.rs @@ -0,0 +1,1257 @@ +use crate::{ + config::EngineConfig, + error::{EngineError, EngineResult}, + lsp_transport::{WorkerHandle, WorkerRequestError, spawn_worker}, +}; +use lsp_types::{ + DiagnosticSeverity, GotoDefinitionResponse, Hover, HoverContents, Location, LocationLink, + MarkedString, Position, Range, Uri, WorkspaceEdit, +}; +use ra_mcp_domain::{ + fault::{Fault, RecoveryDirective}, + lifecycle::{DynamicLifecycle, LifecycleSnapshot}, + types::{ + InvariantViolation, OneIndexedColumn, OneIndexedLine, SourceFilePath, SourceLocation, + SourcePoint, SourcePosition, SourceRange, + }, +}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::Value; +use std::{ + cmp::min, + collections::HashMap, + fs, + sync::Arc, + time::{Duration, Instant, SystemTime}, +}; +use tokio::{sync::Mutex, time::sleep}; +use tracing::{debug, warn}; +use url::Url; + +/// Hover response payload. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct HoverPayload { + /// Rendered markdown/text content, if available. 
+ pub rendered: Option<String>, + /// Symbol range, if rust-analyzer provided one. + pub range: Option<SourceRange>, +} + +/// Diagnostic severity level. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiagnosticLevel { + /// Error severity. + Error, + /// Warning severity. + Warning, + /// Informational severity. + Information, + /// Hint severity. + Hint, +} + +/// One diagnostic record. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DiagnosticEntry { + /// Affected range. + pub range: SourceRange, + /// Severity. + pub level: DiagnosticLevel, + /// Optional diagnostic code. + pub code: Option<String>, + /// User-facing diagnostic message. + pub message: String, +} + +/// Diagnostics report for a single file. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DiagnosticsReport { + /// Entries returned by rust-analyzer. + pub diagnostics: Vec<DiagnosticEntry>, +} + +/// Summary of rename operation impact. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RenameReport { + /// Number of files touched by the edit. + pub files_touched: u64, + /// Number of text edits in total. + pub edits_applied: u64, +} + +/// Aggregate runtime telemetry snapshot for engine behavior. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TelemetrySnapshot { + /// Process uptime in milliseconds. + pub uptime_ms: u64, + /// Current lifecycle snapshot. + pub lifecycle: LifecycleSnapshot, + /// Number of consecutive failures currently tracked by supervisor. + pub consecutive_failures: u32, + /// Number of worker restarts performed. + pub restart_count: u64, + /// Global counters across all requests. + pub totals: TelemetryTotals, + /// Per-method counters and latency aggregates. + pub methods: Vec<MethodTelemetrySnapshot>, + /// Last fault that triggered worker restart, if any. + pub last_fault: Option<Fault>, +} + +/// Total request/fault counters. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TelemetryTotals { + /// Total request attempts issued to rust-analyzer. + pub request_count: u64, + /// Successful request attempts. + pub success_count: u64, + /// LSP response error attempts. + pub response_error_count: u64, + /// Transport/protocol fault attempts. + pub transport_fault_count: u64, + /// Retry attempts performed. + pub retry_count: u64, +} + +/// Per-method telemetry aggregate. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct MethodTelemetrySnapshot { + /// LSP method name. + pub method: String, + /// Total request attempts for this method. + pub request_count: u64, + /// Successful attempts. + pub success_count: u64, + /// LSP response error attempts. + pub response_error_count: u64, + /// Transport/protocol fault attempts. + pub transport_fault_count: u64, + /// Retry attempts for this method. + pub retry_count: u64, + /// Last observed attempt latency in milliseconds. + pub last_latency_ms: Option<u64>, + /// Maximum observed attempt latency in milliseconds. + pub max_latency_ms: u64, + /// Average attempt latency in milliseconds. + pub avg_latency_ms: u64, + /// Last error detail for this method, if any. 
+ pub last_error: Option<String>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RequestMethod { + Hover, + Definition, + References, + Rename, + DocumentDiagnostic, + Raw(&'static str), +} + +impl RequestMethod { + const fn as_lsp_method(self) -> &'static str { + match self { + Self::Hover => "textDocument/hover", + Self::Definition => "textDocument/definition", + Self::References => "textDocument/references", + Self::Rename => "textDocument/rename", + Self::DocumentDiagnostic => "textDocument/diagnostic", + Self::Raw(method) => method, + } + } + + fn retry_delay(self, payload: &crate::lsp_transport::RpcErrorPayload) -> Option<Duration> { + if self.supports_transient_response_retry() + && is_transient_response_error(payload.code, payload.message.as_str()) + { + return Some(self.transient_response_retry_delay()); + } + let retryable_method = matches!( + self.as_lsp_method(), + "textDocument/rename" + | "textDocument/prepareRename" + | "textDocument/definition" + | "textDocument/references" + ); + if !retryable_method + || payload.code != -32602 + || !payload.message.contains("No references found at position") + { + return None; + } + match self.as_lsp_method() { + "textDocument/rename" | "textDocument/prepareRename" => { + Some(Duration::from_millis(1500)) + } + _ => Some(Duration::from_millis(250)), + } + } + + const fn supports_transient_response_retry(self) -> bool { + matches!( + self, + Self::Hover + | Self::Definition + | Self::References + | Self::Rename + | Self::DocumentDiagnostic + ) + } + + fn transient_response_retry_delay(self) -> Duration { + match self { + Self::DocumentDiagnostic => Duration::from_millis(250), + Self::Rename => Duration::from_millis(350), + Self::Hover | Self::Definition | Self::References => Duration::from_millis(150), + Self::Raw(_) => Duration::from_millis(0), + } + } +} + +fn is_transient_response_error(code: i64, message: &str) -> bool { + let normalized = message.to_ascii_lowercase(); + code == -32801 + || code 
== -32802 + || normalized.contains("content modified") + || normalized.contains("document changed") + || normalized.contains("server cancelled") + || normalized.contains("request cancelled") + || normalized.contains("request canceled") +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentIdentifierWire { + uri: String, +} + +#[derive(Debug, Clone, Copy, Serialize)] +struct PositionWire { + line: u32, + character: u32, +} + +impl From<SourcePoint> for PositionWire { + fn from(value: SourcePoint) -> Self { + Self { + line: value.line().to_zero_indexed(), + character: value.column().to_zero_indexed(), + } + } +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentPositionParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, +} + +#[derive(Debug, Clone, Serialize)] +struct ReferencesContextWire { + #[serde(rename = "includeDeclaration")] + include_declaration: bool, +} + +#[derive(Debug, Clone, Serialize)] +struct ReferencesParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, + context: ReferencesContextWire, +} + +#[derive(Debug, Clone, Serialize)] +struct RenameParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, + #[serde(rename = "newName")] + new_name: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DocumentDiagnosticParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, +} + +#[derive(Debug, Clone, Serialize)] +struct VersionedTextDocumentIdentifierWire { + uri: String, + version: i32, +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentContentChangeEventWire { + text: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DidChangeTextDocumentParamsWire { + #[serde(rename = "textDocument")] + text_document: VersionedTextDocumentIdentifierWire, + #[serde(rename = "contentChanges")] + 
content_changes: Vec<TextDocumentContentChangeEventWire>, +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentItemWire { + uri: String, + #[serde(rename = "languageId")] + language_id: &'static str, + version: i32, + text: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DidOpenTextDocumentParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentItemWire, +} + +/// Resilient engine façade. +#[derive(Clone)] +pub struct Engine { + supervisor: Arc<Mutex<Supervisor>>, +} + +struct Supervisor { + config: EngineConfig, + lifecycle: DynamicLifecycle, + worker: Option<WorkerHandle>, + consecutive_failures: u32, + open_documents: HashMap<SourceFilePath, OpenDocumentState>, + telemetry: TelemetryState, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OpenDocumentState { + version: i32, + fingerprint: SourceFileFingerprint, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SourceFileFingerprint { + byte_len: u64, + modified_nanos_since_epoch: u128, +} + +#[derive(Debug)] +struct TelemetryState { + started_at: Instant, + totals: TelemetryTotalsState, + methods: HashMap<&'static str, MethodTelemetryState>, + restart_count: u64, + last_fault: Option<Fault>, +} + +#[derive(Debug, Default)] +struct TelemetryTotalsState { + request_count: u64, + success_count: u64, + response_error_count: u64, + transport_fault_count: u64, + retry_count: u64, +} + +#[derive(Debug, Default)] +struct MethodTelemetryState { + request_count: u64, + success_count: u64, + response_error_count: u64, + transport_fault_count: u64, + retry_count: u64, + total_latency_ms: u128, + last_latency_ms: Option<u64>, + max_latency_ms: u64, + last_error: Option<String>, +} + +impl Engine { + /// Creates a new engine. + #[must_use] + pub fn new(config: EngineConfig) -> Self { + Self { + supervisor: Arc::new(Mutex::new(Supervisor::new(config))), + } + } + + /// Returns current lifecycle snapshot. 
+ pub async fn lifecycle_snapshot(&self) -> LifecycleSnapshot { + let supervisor = self.supervisor.lock().await; + supervisor.snapshot() + } + + /// Returns aggregate request/fault telemetry snapshot. + pub async fn telemetry_snapshot(&self) -> TelemetrySnapshot { + let supervisor = self.supervisor.lock().await; + supervisor.telemetry_snapshot() + } + + /// Executes hover request. + pub async fn hover(&self, position: SourcePosition) -> EngineResult<HoverPayload> { + let document_hint = Some(position.file_path().clone()); + let request = text_document_position_params(&position)?; + let hover = self + .issue_typed_request::<_, Option<Hover>>(RequestMethod::Hover, &request, document_hint) + .await?; + let payload = hover + .map(|hover| -> Result<HoverPayload, EngineError> { + let range = hover + .range + .map(|range| range_to_source_range(position.file_path(), range)) + .transpose()?; + Ok(HoverPayload { + rendered: Some(render_hover_contents(hover.contents)), + range, + }) + }) + .transpose()? + .unwrap_or(HoverPayload { + rendered: None, + range: None, + }); + Ok(payload) + } + + /// Executes definition request. + pub async fn definition(&self, position: SourcePosition) -> EngineResult<Vec<SourceLocation>> { + let document_hint = Some(position.file_path().clone()); + let request = text_document_position_params(&position)?; + let parsed = self + .issue_typed_request::<_, Option<GotoDefinitionResponse>>( + RequestMethod::Definition, + &request, + document_hint, + ) + .await?; + let locations = match parsed { + None => Vec::new(), + Some(GotoDefinitionResponse::Scalar(location)) => { + vec![source_location_from_lsp_location(location)?] 
+ } + Some(GotoDefinitionResponse::Array(locations)) => locations + .into_iter() + .map(source_location_from_lsp_location) + .collect::<Result<Vec<_>, _>>()?, + Some(GotoDefinitionResponse::Link(links)) => links + .into_iter() + .map(source_location_from_lsp_link) + .collect::<Result<Vec<_>, _>>()?, + }; + Ok(locations) + } + + /// Executes references request. + pub async fn references(&self, position: SourcePosition) -> EngineResult<Vec<SourceLocation>> { + let request = ReferencesParamsWire { + text_document: text_document_identifier(position.file_path())?, + position: PositionWire::from(position.point()), + context: ReferencesContextWire { + include_declaration: true, + }, + }; + let parsed = self + .issue_typed_request::<_, Option<Vec<Location>>>( + RequestMethod::References, + &request, + Some(position.file_path().clone()), + ) + .await?; + parsed + .unwrap_or_default() + .into_iter() + .map(source_location_from_lsp_location) + .collect::<Result<Vec<_>, _>>() + } + + /// Executes rename request. + pub async fn rename_symbol( + &self, + position: SourcePosition, + new_name: String, + ) -> EngineResult<RenameReport> { + let request = RenameParamsWire { + text_document: text_document_identifier(position.file_path())?, + position: PositionWire::from(position.point()), + new_name, + }; + let edit = self + .issue_typed_request::<_, WorkspaceEdit>( + RequestMethod::Rename, + &request, + Some(position.file_path().clone()), + ) + .await?; + Ok(summarize_workspace_edit(edit)) + } + + /// Executes document diagnostics request. 
+ pub async fn diagnostics(&self, file_path: SourceFilePath) -> EngineResult<DiagnosticsReport> { + let request = DocumentDiagnosticParamsWire { + text_document: text_document_identifier(&file_path)?, + }; + let response = self + .issue_request( + RequestMethod::DocumentDiagnostic, + &request, + Some(file_path.clone()), + ) + .await?; + parse_diagnostics_report(&file_path, response) + } + + /// Executes an arbitrary typed LSP request and returns raw JSON payload. + pub async fn raw_lsp_request( + &self, + method: &'static str, + params: Value, + ) -> EngineResult<Value> { + let document_hint = source_file_path_hint_from_request_params(¶ms)?; + self.issue_request(RequestMethod::Raw(method), ¶ms, document_hint) + .await + } + + async fn issue_typed_request<P, R>( + &self, + method: RequestMethod, + params: &P, + document_hint: Option<SourceFilePath>, + ) -> EngineResult<R> + where + P: Serialize, + R: DeserializeOwned, + { + let response = self.issue_request(method, params, document_hint).await?; + serde_json::from_value::<R>(response).map_err(|error| EngineError::InvalidPayload { + method: method.as_lsp_method(), + message: error.to_string(), + }) + } + + async fn issue_request<P>( + &self, + method: RequestMethod, + params: &P, + document_hint: Option<SourceFilePath>, + ) -> EngineResult<Value> + where + P: Serialize, + { + let max_attempts = 2_u8; + let mut attempt = 0_u8; + while attempt < max_attempts { + attempt = attempt.saturating_add(1); + let (worker, request_timeout) = { + let mut supervisor = self.supervisor.lock().await; + let worker = supervisor.ensure_worker().await?; + if let Some(file_path) = document_hint.as_ref() { + supervisor.synchronize_document(&worker, file_path).await?; + } + (worker, supervisor.request_timeout()) + }; + + let attempt_started_at = Instant::now(); + let result = worker + .send_request(method.as_lsp_method(), params, request_timeout) + .await; + let latency = attempt_started_at.elapsed(); + match result { + Ok(value) => { + let 
mut supervisor = self.supervisor.lock().await;
                    // Success: record latency telemetry and reset the
                    // supervisor's consecutive-failure counter.
                    supervisor.record_success(method.as_lsp_method(), latency);
                    return Ok(value);
                }
                Err(WorkerRequestError::Response(payload)) => {
                    // Well-formed LSP error response. Retry only when the
                    // method classifies the payload as transient (it yields a
                    // retry delay) and attempts remain.
                    let retry_delay = (attempt < max_attempts)
                        .then(|| method.retry_delay(&payload))
                        .flatten();
                    let should_retry = retry_delay.is_some();
                    {
                        // Scoped block: the supervisor lock is released before
                        // the retry sleep below.
                        let mut supervisor = self.supervisor.lock().await;
                        supervisor.record_response_error(
                            method.as_lsp_method(),
                            latency,
                            payload.code,
                            format_lsp_response_error_detail(&payload),
                            should_retry,
                        );
                    }

                    if let Some(retry_delay) = retry_delay {
                        debug!(
                            attempt,
                            method = method.as_lsp_method(),
                            code = payload.code,
                            delay_ms = retry_delay.as_millis(),
                            "retrying request after transient lsp response error"
                        );
                        sleep(retry_delay).await;
                        continue;
                    }
                    return Err(EngineError::from(payload));
                }
                Err(WorkerRequestError::Fault(fault)) => {
                    // Transport-level fault: the fault's recovery directive
                    // decides between in-place retry, restart-and-replay, and
                    // aborting the request.
                    let directive = fault.directive();
                    let will_retry = matches!(
                        directive,
                        RecoveryDirective::RetryInPlace | RecoveryDirective::RestartAndReplay
                    ) && attempt < max_attempts;
                    {
                        let mut supervisor = self.supervisor.lock().await;
                        supervisor.record_transport_fault(
                            method.as_lsp_method(),
                            latency,
                            fault.detail.message.clone(),
                            will_retry,
                        );
                    }

                    match directive {
                        RecoveryDirective::RetryInPlace => {
                            // Retry against the same worker; no restart, no
                            // backoff delay.
                            debug!(
                                attempt,
                                method = method.as_lsp_method(),
                                "retrying request in-place after fault"
                            );
                            if attempt >= max_attempts {
                                return Err(EngineError::Fault(fault));
                            }
                        }
                        RecoveryDirective::RestartAndReplay => {
                            // record_fault recycles the worker (terminate +
                            // backoff sleep) before the next loop iteration
                            // replays the request against a fresh worker.
                            let mut supervisor = self.supervisor.lock().await;
                            supervisor.record_fault(fault.clone()).await?;
                            if attempt >= max_attempts {
                                return Err(EngineError::Fault(fault));
                            }
                            debug!(
                                attempt,
                                method = method.as_lsp_method(),
                                "restarting worker and replaying request"
                            );
                        }
                        RecoveryDirective::AbortRequest => {
                            // Non-retryable: recycle the worker but surface the
                            // fault to the caller immediately.
                            let mut supervisor = self.supervisor.lock().await;
                            supervisor.record_fault(fault.clone()).await?;
                            return Err(EngineError::Fault(fault));
                        }
                    }
                }
            }
        }
        // All attempts exhausted without success or a terminal error: report a
        // resource fault tagged with the current worker generation.
        Err(EngineError::Fault(Fault::new(
            self.lifecycle_generation().await,
            ra_mcp_domain::fault::FaultClass::Resource,
            ra_mcp_domain::fault::FaultCode::RequestTimedOut,
            ra_mcp_domain::fault::FaultDetail::new(format!(
                "exhausted retries for method {}",
                method.as_lsp_method()
            )),
        )))
    }

    /// Current worker generation as reported by the supervisor's lifecycle.
    async fn lifecycle_generation(&self) -> ra_mcp_domain::types::Generation {
        let supervisor = self.supervisor.lock().await;
        supervisor.generation()
    }
}

impl TelemetryState {
    /// Fresh telemetry: zeroed totals, empty per-method table, uptime clock
    /// started now.
    fn new() -> Self {
        Self {
            started_at: Instant::now(),
            totals: TelemetryTotalsState::default(),
            methods: HashMap::new(),
            restart_count: 0,
            last_fault: None,
        }
    }

    /// Count a successful request against totals and the per-method entry;
    /// clears the method's sticky `last_error`. All counters saturate rather
    /// than wrap.
    fn record_success(&mut self, method: &'static str, latency: Duration) {
        self.totals.request_count = self.totals.request_count.saturating_add(1);
        self.totals.success_count = self.totals.success_count.saturating_add(1);
        let entry = self.methods.entry(method).or_default();
        entry.request_count = entry.request_count.saturating_add(1);
        entry.success_count = entry.success_count.saturating_add(1);
        entry.record_latency(latency);
        entry.last_error = None;
    }

    /// Count an LSP response error; `retry_performed` additionally bumps the
    /// retry counters. `detail` becomes the method's sticky `last_error`.
    fn record_response_error(
        &mut self,
        method: &'static str,
        latency: Duration,
        detail: String,
        retry_performed: bool,
    ) {
        self.totals.request_count = self.totals.request_count.saturating_add(1);
        self.totals.response_error_count = self.totals.response_error_count.saturating_add(1);
        if retry_performed {
            self.totals.retry_count = self.totals.retry_count.saturating_add(1);
        }

        let entry = self.methods.entry(method).or_default();
        entry.request_count = entry.request_count.saturating_add(1);
        entry.response_error_count = entry.response_error_count.saturating_add(1);
        if retry_performed {
            entry.retry_count = entry.retry_count.saturating_add(1);
        }
        entry.record_latency(latency);
        entry.last_error = Some(detail);
    }

    /// Count a transport-level fault; mirrors `record_response_error` but
    /// against the transport-fault counters.
    fn record_transport_fault(
        &mut self,
        method: &'static str,
        latency: Duration,
        detail: String,
        retry_performed: bool,
    ) {
        self.totals.request_count = self.totals.request_count.saturating_add(1);
        self.totals.transport_fault_count = self.totals.transport_fault_count.saturating_add(1);
        if retry_performed {
            self.totals.retry_count = self.totals.retry_count.saturating_add(1);
        }

        let entry = self.methods.entry(method).or_default();
        entry.request_count = entry.request_count.saturating_add(1);
        entry.transport_fault_count = entry.transport_fault_count.saturating_add(1);
        if retry_performed {
            entry.retry_count = entry.retry_count.saturating_add(1);
        }
        entry.record_latency(latency);
        entry.last_error = Some(detail);
    }

    /// Count a worker restart and remember the fault that triggered it.
    fn record_restart(&mut self, fault: Fault) {
        self.restart_count = self.restart_count.saturating_add(1);
        self.last_fault = Some(fault);
    }

    /// Build an externally visible snapshot. Per-method entries are sorted by
    /// method name so the output order is deterministic.
    fn snapshot(
        &self,
        lifecycle: LifecycleSnapshot,
        consecutive_failures: u32,
    ) -> TelemetrySnapshot {
        let mut methods = self
            .methods
            .iter()
            .map(|(method, entry)| MethodTelemetrySnapshot {
                method: (*method).to_owned(),
                request_count: entry.request_count,
                success_count: entry.success_count,
                response_error_count: entry.response_error_count,
                transport_fault_count: entry.transport_fault_count,
                retry_count: entry.retry_count,
                last_latency_ms: entry.last_latency_ms,
                max_latency_ms: entry.max_latency_ms,
                avg_latency_ms: entry.average_latency_ms(),
                last_error: entry.last_error.clone(),
            })
            .collect::<Vec<_>>();
        methods.sort_by(|left, right| left.method.cmp(&right.method));

        let uptime_ms = duration_millis_u64(self.started_at.elapsed());
        TelemetrySnapshot {
            uptime_ms,
            lifecycle,
            consecutive_failures,
            restart_count: self.restart_count,
            totals: TelemetryTotals {
                request_count: self.totals.request_count,
                success_count: self.totals.success_count,
                response_error_count: self.totals.response_error_count,
                transport_fault_count: self.totals.transport_fault_count,
                retry_count: self.totals.retry_count,
            },
            methods,
            last_fault: self.last_fault.clone(),
        }
    }
}

impl MethodTelemetryState {
    /// Fold one latency sample into last/max/total; total accumulates in u128
    /// so long uptimes cannot overflow.
    fn record_latency(&mut self, latency: Duration) {
        let latency_ms = duration_millis_u64(latency);
        self.last_latency_ms = Some(latency_ms);
        self.max_latency_ms = self.max_latency_ms.max(latency_ms);
        self.total_latency_ms = self.total_latency_ms.saturating_add(latency_ms as u128);
    }

    /// Mean latency in ms, 0 when no requests were recorded; clamps to
    /// u64::MAX instead of truncating.
    fn average_latency_ms(&self) -> u64 {
        if self.request_count == 0 {
            return 0;
        }
        let avg = self.total_latency_ms / u128::from(self.request_count);
        if avg > u128::from(u64::MAX) {
            u64::MAX
        } else {
            avg as u64
        }
    }
}

/// Duration -> whole milliseconds, clamped to u64::MAX (no truncation).
fn duration_millis_u64(duration: Duration) -> u64 {
    let millis = duration.as_millis();
    if millis > u128::from(u64::MAX) {
        u64::MAX
    } else {
        millis as u64
    }
}

impl Supervisor {
    /// Cold supervisor: no worker spawned yet, empty document table.
    fn new(config: EngineConfig) -> Self {
        Self {
            config,
            lifecycle: DynamicLifecycle::cold(),
            worker: None,
            consecutive_failures: 0,
            open_documents: HashMap::new(),
            telemetry: TelemetryState::new(),
        }
    }

    fn request_timeout(&self) -> Duration {
        self.config.request_timeout
    }

    /// Bring the worker's view of `file_path` in line with the file on disk.
    /// Already-open documents whose fingerprint (size + mtime) is unchanged
    /// are skipped; changed ones get a full-text `didChange` with a bumped
    /// version; unknown ones get a `didOpen` at version 1.
    async fn synchronize_document(
        &mut self,
        worker: &WorkerHandle,
        file_path: &SourceFilePath,
    ) -> EngineResult<()> {
        let fingerprint = capture_source_file_fingerprint(file_path)?;
        if let Some(existing) = self.open_documents.get_mut(file_path) {
            if existing.fingerprint == fingerprint {
                // Unchanged on disk: nothing to send.
                return Ok(());
            }
            let text = fs::read_to_string(file_path.as_path())?;
            let next_version = existing.version.saturating_add(1);
            let params = DidChangeTextDocumentParamsWire {
                text_document: VersionedTextDocumentIdentifierWire {
                    uri: file_uri_string_from_source_path(file_path)?,
                    version: next_version,
                },
                // Single change event with no range = full-document sync.
                content_changes: vec![TextDocumentContentChangeEventWire { text }],
            };
            worker
                .send_notification("textDocument/didChange", &params)
                .await
                .map_err(EngineError::from)?;
            existing.version = next_version;
            existing.fingerprint = fingerprint;
            return Ok(());
        }

        let text = fs::read_to_string(file_path.as_path())?;
        let params = DidOpenTextDocumentParamsWire {
            text_document: TextDocumentItemWire {
                uri: file_uri_string_from_source_path(file_path)?,
                language_id: "rust",
                version: 1,
                text,
            },
        };
        worker
            .send_notification("textDocument/didOpen", &params)
            .await
            .map_err(EngineError::from)?;
        let _previous = self.open_documents.insert(
            file_path.clone(),
            OpenDocumentState {
                version: 1,
                fingerprint,
            },
        );
        Ok(())
    }

    fn snapshot(&self) -> LifecycleSnapshot {
        self.lifecycle.snapshot()
    }

    fn telemetry_snapshot(&self) -> TelemetrySnapshot {
        let lifecycle = self.snapshot();
        self.telemetry
            .snapshot(lifecycle, self.consecutive_failures)
    }

    /// Generation carried by whichever lifecycle state we are in; every
    /// variant carries one.
    fn generation(&self) -> ra_mcp_domain::types::Generation {
        let snapshot = self.snapshot();
        match snapshot {
            LifecycleSnapshot::Cold { generation }
            | LifecycleSnapshot::Starting { generation }
            | LifecycleSnapshot::Ready { generation }
            | LifecycleSnapshot::Recovering { generation, .. } => generation,
        }
    }

    /// Return the live worker, recycling it first if it has flagged itself
    /// terminal; spawns a fresh worker when none exists.
    async fn ensure_worker(&mut self) -> EngineResult<WorkerHandle> {
        if let Some(worker) = self.worker.clone() {
            if let Some(fault) = worker.terminal_fault() {
                warn!(
                    generation = fault.generation.get(),
                    "worker marked terminal, recycling"
                );
                self.record_fault(fault).await?;
            } else {
                return Ok(worker);
            }
        }
        self.spawn_worker().await
    }

    /// Drive the lifecycle through startup and launch a new worker process.
    /// On success the failure counter and open-document table are reset; on
    /// failure the fault is recorded (which applies backoff) and returned.
    async fn spawn_worker(&mut self) -> EngineResult<WorkerHandle> {
        self.lifecycle = self.lifecycle.clone().begin_startup()?;
        let generation = self.generation();
        let started = spawn_worker(&self.config, generation).await;
        match started {
            Ok(worker) => {
                self.lifecycle = self.lifecycle.clone().complete_startup()?;
                self.worker = Some(worker.clone());
                self.consecutive_failures = 0;
                // New process knows nothing about previously opened documents.
                self.open_documents.clear();
                Ok(worker)
            }
            Err(fault) => {
                self.record_fault(fault.clone()).await?;
                Err(EngineError::Fault(fault))
            }
        }
    }

    /// Transition the lifecycle into recovery, terminate the current worker,
    /// and sleep the exponential-backoff delay before returning.
    /// NOTE(review): the caller holds the supervisor lock across this sleep,
    /// so other requests are blocked for the backoff duration — presumably
    /// intentional serialization of recovery; confirm.
    async fn record_fault(&mut self, fault: Fault) -> EngineResult<()> {
        self.lifecycle = fracture_or_force_recovery(self.lifecycle.clone(), fault.clone())?;
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        self.telemetry.record_restart(fault.clone());

        if let Some(worker) = self.worker.take() {
            worker.terminate().await;
        }
        self.open_documents.clear();

        let delay = self.next_backoff_delay();
        debug!(
            failures = self.consecutive_failures,
            delay_ms = delay.as_millis(),
            "applying restart backoff delay"
        );
        sleep(delay).await;
        Ok(())
    }

    /// Success resets the consecutive-failure streak before recording.
    fn record_success(&mut self, method: &'static str, latency: Duration) {
        self.consecutive_failures = 0;
        self.telemetry.record_success(method, latency);
    }

    /// Flatten code + message into the detail string stored in telemetry.
    fn record_response_error(
        &mut self,
        method: &'static str,
        latency: Duration,
        code: i64,
        message: String,
        retry_performed: bool,
    ) {
        let detail = format!("code={code} message={message}");
        self.telemetry
            .record_response_error(method, latency, detail, retry_performed);
    }

    /// Forward a transport fault sample to telemetry (no lifecycle change).
    fn record_transport_fault(
        &mut self,
        method: &'static str,
        latency: Duration,
        detail: String,
        retry_performed: bool,
    ) {
        self.telemetry
            .record_transport_fault(method, latency, detail, retry_performed);
    }

    /// Exponential backoff: floor * 2^(failures - 1), clamped to the policy
    /// ceiling. The shift guard keeps the multiplier from overflowing u32 for
    /// 32+ consecutive failures; saturating_sub makes failures == 0 behave
    /// like a single failure (multiplier 1).
    fn next_backoff_delay(&self) -> Duration {
        let exponent = self.consecutive_failures.saturating_sub(1);
        let multiplier = if exponent >= 31 {
            u32::MAX
        } else {
            1_u32 << exponent
        };
        let scaled = self.config.backoff_policy.floor.saturating_mul(multiplier);
        min(scaled, self.config.backoff_policy.ceiling)
    }
}

/// Move the lifecycle into its fractured state; if the current state refuses
/// to fracture directly, force it through `begin_startup` first so a fracture
/// transition is always reachable.
fn fracture_or_force_recovery(
    lifecycle: DynamicLifecycle,
    fault: Fault,
) -> EngineResult<DynamicLifecycle> {
    match lifecycle.clone().fracture(fault.clone()) {
        Ok(next) => Ok(next),
        Err(_error) => {
            let started = lifecycle.begin_startup()?;
            started.fracture(fault).map_err(EngineError::from)
        }
    }
}

/// Wire-level `TextDocumentIdentifier` for the given source file.
fn text_document_identifier(
    file_path: &SourceFilePath,
) -> EngineResult<TextDocumentIdentifierWire> {
    Ok(TextDocumentIdentifierWire {
        uri: file_uri_string_from_source_path(file_path)?,
    })
}

/// Wire-level `TextDocumentPositionParams` for the given source position.
fn text_document_position_params(
    position: &SourcePosition,
) -> EngineResult<TextDocumentPositionParamsWire> {
    Ok(TextDocumentPositionParamsWire {
        text_document: text_document_identifier(position.file_path())?,
        position: PositionWire::from(position.point()),
    })
}

/// Human-readable "code=… message=… [data=…]" rendering of an LSP error
/// payload, used as telemetry detail text.
fn format_lsp_response_error_detail(payload: &crate::lsp_transport::RpcErrorPayload) -> String {
    let crate::lsp_transport::RpcErrorPayload {
        code,
        message,
        data,
    } = payload;
    match data {
        Some(data) => format!("code={code} message={message} data={data}"),
        None => format!("code={code} message={message}"),
    }
}

/// `file://` URI string for a source path; fails for paths that cannot form a
/// file URL (e.g. relative paths).
fn file_uri_string_from_source_path(file_path: &SourceFilePath) -> EngineResult<String> {
    let file_url =
        Url::from_file_path(file_path.as_path()).map_err(|()| EngineError::InvalidFileUrl)?;
    Ok(file_url.to_string())
}

/// Best-effort extraction of `params.textDocument.uri` from raw request
/// params; `Ok(None)` when the field is absent, `Err` only when a present URI
/// is malformed.
fn source_file_path_hint_from_request_params(
    params: &Value,
) -> EngineResult<Option<SourceFilePath>> {
    let maybe_uri = params
        .get("textDocument")
        .and_then(Value::as_object)
        .and_then(|document| document.get("uri"))
        .and_then(Value::as_str);
    let Some(uri) = maybe_uri else {
        return Ok(None);
    };
    let file_path = source_file_path_from_file_uri_str(uri)?;
    Ok(Some(file_path))
}

/// Parse a `file://` URI string into a validated `SourceFilePath`.
fn source_file_path_from_file_uri_str(uri: &str) -> EngineResult<SourceFilePath> {
    let file_url = Url::parse(uri).map_err(|_error| EngineError::InvalidFileUrl)?;
    let file_path = file_url
        .to_file_path()
        .map_err(|()| EngineError::InvalidFileUrl)?;
    SourceFilePath::try_new(file_path).map_err(EngineError::from)
}

/// Cheap change-detection fingerprint: byte length + mtime in nanoseconds
/// since the epoch. Missing/older-than-epoch mtimes degrade to constants
/// rather than erroring.
fn capture_source_file_fingerprint(
    file_path: &SourceFilePath,
) -> EngineResult<SourceFileFingerprint> {
    let metadata = fs::metadata(file_path.as_path())?;
    let modified = metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH);
    let modified_nanos_since_epoch = modified
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or(Duration::ZERO)
        .as_nanos();
    Ok(SourceFileFingerprint {
        byte_len: metadata.len(),
        modified_nanos_since_epoch,
    })
}

/// LocationLink -> SourceLocation, using the target *selection* range start.
fn source_location_from_lsp_link(link: LocationLink) -> EngineResult<SourceLocation> {
    let uri = link.target_uri;
    let range = link.target_selection_range;
    source_location_from_uri_and_position(uri, range.start)
}

/// Location -> SourceLocation, using the range start.
fn source_location_from_lsp_location(location: Location) -> EngineResult<SourceLocation> {
    source_location_from_uri_and_position(location.uri, location.range.start)
}

/// Convert an LSP (0-indexed) URI + position to the domain's 1-indexed
/// SourceLocation.
fn source_location_from_uri_and_position(
    uri: Uri,
    position: Position,
) -> EngineResult<SourceLocation> {
    let file_url = Url::parse(uri.as_str()).map_err(|_error| EngineError::InvalidFileUrl)?;
    let path = file_url
        .to_file_path()
        .map_err(|()| EngineError::InvalidFileUrl)?;
    let file_path = SourceFilePath::try_new(path)?;
    let point = SourcePoint::new(
        OneIndexedLine::try_new(u64::from(position.line).saturating_add(1))?,
        OneIndexedColumn::try_new(u64::from(position.character).saturating_add(1))?,
    );
    Ok(SourceLocation::new(file_path, point))
}

/// LSP 0-indexed Range -> domain 1-indexed SourceRange for `file_path`.
fn range_to_source_range(
    file_path: &SourceFilePath,
    range: Range,
) -> Result<SourceRange, InvariantViolation> {
    let start = SourcePoint::new(
        OneIndexedLine::try_new(u64::from(range.start.line).saturating_add(1))?,
        OneIndexedColumn::try_new(u64::from(range.start.character).saturating_add(1))?,
    );
    let end = SourcePoint::new(
        OneIndexedLine::try_new(u64::from(range.end.line).saturating_add(1))?,
        OneIndexedColumn::try_new(u64::from(range.end.character).saturating_add(1))?,
    );
    SourceRange::try_new(file_path.clone(), start, end)
}

/// Flatten any of the three HoverContents shapes into one display string;
/// array items are newline-joined.
fn render_hover_contents(contents: HoverContents) -> String {
    match contents {
        HoverContents::Scalar(marked_string) => marked_string_to_string(marked_string),
        HoverContents::Array(items) => items
            .into_iter()
            .map(marked_string_to_string)
            .collect::<Vec<_>>()
            .join("\n"),
        HoverContents::Markup(markup) => markup.value,
    }
}

/// MarkedString -> plain text; language strings become fenced code blocks.
fn marked_string_to_string(marked: MarkedString) -> String {
    match marked {
        MarkedString::String(value) => value,
        MarkedString::LanguageString(language_string) => {
            format!(
                "```{}\n{}\n```",
                language_string.language, language_string.value
            )
        }
    }
}

/// Reduce a WorkspaceEdit to (files touched, edits applied) counts, covering
/// both the legacy `changes` map and the newer `document_changes` forms.
/// Resource operations (create/rename/delete) count as one applied edit each
/// and register the affected URI with zero text edits. All arithmetic
/// saturates.
fn summarize_workspace_edit(edit: WorkspaceEdit) -> RenameReport {
    let mut touched = HashMap::<String, u64>::new();
    let mut edits_applied = 0_u64;

    if let Some(changes) = edit.changes {
        for (uri, edits) in changes {
            let edit_count = u64::try_from(edits.len()).unwrap_or(u64::MAX);
            let _previous = touched.insert(uri.as_str().to_owned(), edit_count);
            edits_applied = edits_applied.saturating_add(edit_count);
        }
    }

    if let Some(document_changes) = edit.document_changes {
        match document_changes {
            lsp_types::DocumentChanges::Edits(edits) => {
                for document_edit in edits {
                    let uri = document_edit.text_document.uri;
                    let edit_count = u64::try_from(document_edit.edits.len()).unwrap_or(u64::MAX);
                    let _entry = touched
                        .entry(uri.as_str().to_owned())
                        .and_modify(|count| *count = count.saturating_add(edit_count))
                        .or_insert(edit_count);
                    edits_applied = edits_applied.saturating_add(edit_count);
                }
            }
            lsp_types::DocumentChanges::Operations(operations) => {
                edits_applied = edits_applied
                    .saturating_add(u64::try_from(operations.len()).unwrap_or(u64::MAX));
                for operation in operations {
                    match operation {
                        lsp_types::DocumentChangeOperation::Op(operation) => match operation {
                            lsp_types::ResourceOp::Create(create) => {
                                let _entry =
                                    touched.entry(create.uri.as_str().to_owned()).or_insert(0);
                            }
                            lsp_types::ResourceOp::Rename(rename) => {
                                // The destination URI is the file that exists
                                // afterwards.
                                let _entry = touched
                                    .entry(rename.new_uri.as_str().to_owned())
                                    .or_insert(0);
                            }
                            lsp_types::ResourceOp::Delete(delete) => {
                                let _entry =
                                    touched.entry(delete.uri.as_str().to_owned()).or_insert(0);
                            }
                        },
                        lsp_types::DocumentChangeOperation::Edit(edit) => {
                            let edit_count = u64::try_from(edit.edits.len()).unwrap_or(u64::MAX);
                            let _entry = touched
                                .entry(edit.text_document.uri.as_str().to_owned())
                                .and_modify(|count| *count = count.saturating_add(edit_count))
                                .or_insert(edit_count);
                        }
                    }
                }
            }
        }
    }

    RenameReport {
        files_touched: u64::try_from(touched.len()).unwrap_or(u64::MAX),
        edits_applied,
    }
}

/// Wire form of an LSP pull-diagnostics report, discriminated by the `kind`
/// tag ("full" | "unchanged").
#[derive(Debug, Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
enum DiagnosticReportWire {
    Full { items: Vec<DiagnosticWire> },
    Unchanged {},
}

/// Subset of an LSP Diagnostic that the engine surfaces.
#[derive(Debug, Deserialize)]
struct DiagnosticWire {
    range: Range,
    severity: Option<DiagnosticSeverity>,
    // `code` may be a string or a number in LSP; kept as raw JSON here.
    code: Option<Value>,
    message: String,
}

/// Parse a raw `textDocument/diagnostic` response into the domain report.
/// "unchanged" maps to an empty diagnostics list; missing severity defaults
/// to Information, as do any future severity values.
fn parse_diagnostics_report(
    file_path: &SourceFilePath,
    value: Value,
) -> EngineResult<DiagnosticsReport> {
    let parsed = serde_json::from_value::<DiagnosticReportWire>(value).map_err(|error| {
        EngineError::InvalidPayload {
            method: "textDocument/diagnostic",
            message: error.to_string(),
        }
    })?;
    match parsed {
        DiagnosticReportWire::Unchanged {} => Ok(DiagnosticsReport {
            diagnostics: Vec::new(),
        }),
        DiagnosticReportWire::Full { items } => {
            let diagnostics = items
                .into_iter()
                .map(|item| {
                    let range = range_to_source_range(file_path, item.range)?;
                    let level = match item.severity.unwrap_or(DiagnosticSeverity::INFORMATION) {
                        DiagnosticSeverity::ERROR => DiagnosticLevel::Error,
                        DiagnosticSeverity::WARNING => DiagnosticLevel::Warning,
                        DiagnosticSeverity::INFORMATION => DiagnosticLevel::Information,
                        DiagnosticSeverity::HINT => DiagnosticLevel::Hint,
                        _ => DiagnosticLevel::Information,
                    };
                    let code = item.code.map(|value| match value {
                        Value::String(message) => message,
                        Value::Number(number) => number.to_string(),
                        other => other.to_string(),
                    });
                    Ok(DiagnosticEntry {
                        range,
                        level,
                        code,
                        message: item.message,
                    })
                })
                .collect::<Result<Vec<_>, InvariantViolation>>()?;
            Ok(DiagnosticsReport { diagnostics })
        }
    }
}
diff --git a/crates/ra-mcp-engine/tests/engine_recovery.rs b/crates/ra-mcp-engine/tests/engine_recovery.rs
new file mode 100644
index 0000000..a7f2db8
--- /dev/null
+++ b/crates/ra-mcp-engine/tests/engine_recovery.rs
@@ -0,0 +1,353 @@
//! Integration tests for engine restart and transport recovery.

// `use … as _;` lines acknowledge workspace dependencies that this test file
// does not reference directly (the crate denies unused dependencies).
use lsp_types as _;
use ra_mcp_domain::{
    lifecycle::LifecycleSnapshot,
    types::{
        OneIndexedColumn, OneIndexedLine, SourceFilePath, SourcePoint, SourcePosition,
        WorkspaceRoot,
    },
};
use ra_mcp_engine::{BackoffPolicy, Engine, EngineConfig, EngineError};
use serde as _;
use serde_json::{self, json};
use serial_test::serial;
use std::{error::Error, fs, path::PathBuf, time::Duration};
use tempfile::TempDir;
use thiserror as _;
use tracing as _;
use url as _;

/// Happy path: the stable fake server answers hover, definition, references,
/// rename, and diagnostics, and the lifecycle ends up Ready.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn stable_fake_server_handles_core_requests() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let config = make_engine_config(&fixture, vec!["--mode".into(), "stable".into()])?;
    let engine = Engine::new(config);
    let position = fixture.position()?;

    let hover = engine.hover(position.clone()).await?;
    assert_eq!(hover.rendered.as_deref(), Some("hover::ok"));

    let definitions = engine.definition(position.clone()).await?;
    assert_eq!(definitions.len(), 1);
    // Expected line/column are the canned values the fake server returns.
    assert_eq!(definitions[0].line().get(), 3);
    assert_eq!(definitions[0].column().get(), 4);

    let references = engine.references(position.clone()).await?;
    assert_eq!(references.len(), 1);

    let rename = engine
        .rename_symbol(position.clone(), "renamed".to_owned())
        .await?;
    assert!(rename.files_touched >= 1);
    assert!(rename.edits_applied >= 1);

    let diagnostics = engine.diagnostics(fixture.source_file_path()?).await?;
    assert_eq!(diagnostics.diagnostics.len(), 1);

    let snapshot = engine.lifecycle_snapshot().await;
    assert!(matches!(snapshot, LifecycleSnapshot::Ready { .. }));

    Ok(())
}

/// Four successful requests must yield exactly four success samples in both
/// the totals and the per-method telemetry, with no faults or retries.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn stable_fake_server_reports_success_telemetry() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let config = make_engine_config(&fixture, vec!["--mode".into(), "stable".into()])?;
    let engine = Engine::new(config);
    let position = fixture.position()?;

    let _hover = engine.hover(position.clone()).await?;
    let _definition = engine.definition(position.clone()).await?;
    let _references = engine.references(position.clone()).await?;
    let _diagnostics = engine.diagnostics(fixture.source_file_path()?).await?;

    let telemetry = engine.telemetry_snapshot().await;
    assert_eq!(telemetry.totals.request_count, 4);
    assert_eq!(telemetry.totals.success_count, 4);
    assert_eq!(telemetry.totals.response_error_count, 0);
    assert_eq!(telemetry.totals.transport_fault_count, 0);
    assert_eq!(telemetry.totals.retry_count, 0);
    assert_eq!(telemetry.restart_count, 0);
    assert!(telemetry.last_fault.is_none());
    assert_eq!(telemetry.consecutive_failures, 0);

    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/hover",
        MethodExpectation::new(1, 1, 0, 0, 0),
    );
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/definition",
        MethodExpectation::new(1, 1, 0, 0, 0),
    );
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/references",
        MethodExpectation::new(1, 1, 0, 0, 0),
    );
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/diagnostic",
        MethodExpectation::new(1, 1, 0, 0, 0),
    );

    Ok(())
}

/// One injected server-cancelled response must be retried transparently:
/// 2 requests total, 1 success, 1 response error, 1 retry, no restart.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn diagnostics_retry_server_cancelled_response() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let config = make_engine_config(
        &fixture,
        vec![
            "--mode".into(),
            "stable".into(),
            "--diagnostic-cancel-count".into(),
            "1".into(),
        ],
    )?;
    let engine = Engine::new(config);

    let diagnostics = engine.diagnostics(fixture.source_file_path()?).await?;
    assert_eq!(diagnostics.diagnostics.len(), 1);

    let telemetry = engine.telemetry_snapshot().await;
    assert_eq!(telemetry.totals.request_count, 2);
    assert_eq!(telemetry.totals.success_count, 1);
    assert_eq!(telemetry.totals.response_error_count, 1);
    assert_eq!(telemetry.totals.transport_fault_count, 0);
    assert_eq!(telemetry.totals.retry_count, 1);
    assert_eq!(telemetry.restart_count, 0);
    assert_eq!(telemetry.consecutive_failures, 0);
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/diagnostic",
        MethodExpectation::new(2, 1, 1, 0, 1),
    );

    Ok(())
}

/// A worker crash on the first hover must be recovered transparently; the
/// crash marker proves the fake server actually died, and the generation
/// counter proves a restart happened.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn engine_recovers_after_first_hover_crash() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let marker = fixture.path().join("crash-marker");
    let args = vec![
        "--mode".into(),
        "crash_on_first_hover".into(),
        "--crash-marker".into(),
        marker.display().to_string(),
    ];
    let config = make_engine_config(&fixture, args)?;
    let engine = Engine::new(config);

    let hover = engine.hover(fixture.position()?).await?;
    assert_eq!(hover.rendered.as_deref(), Some("hover::ok"));
    assert!(marker.exists());

    let snapshot = engine.lifecycle_snapshot().await;
    let generation = if let LifecycleSnapshot::Ready { generation } = snapshot {
        generation.get()
    } else {
        return Err("expected ready lifecycle snapshot after successful recovery".into());
    };
    assert!(generation >= 2);

    Ok(())
}

/// Same crash scenario as above, but asserting the telemetry bookkeeping:
/// one transport fault, one retry, one restart, and a recorded last fault.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn crash_recovery_records_transport_fault_retry_and_restart() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let marker = fixture.path().join("crash-marker");
    let args = vec![
        "--mode".into(),
        "crash_on_first_hover".into(),
        "--crash-marker".into(),
        marker.display().to_string(),
    ];
    let config = make_engine_config(&fixture, args)?;
    let engine = Engine::new(config);

    let hover = engine.hover(fixture.position()?).await?;
    assert_eq!(hover.rendered.as_deref(), Some("hover::ok"));

    let telemetry = engine.telemetry_snapshot().await;
    assert_eq!(telemetry.totals.request_count, 2);
    assert_eq!(telemetry.totals.success_count, 1);
    assert_eq!(telemetry.totals.response_error_count, 0);
    assert_eq!(telemetry.totals.transport_fault_count, 1);
    assert_eq!(telemetry.totals.retry_count, 1);
    assert_eq!(telemetry.restart_count, 1);
    assert_eq!(telemetry.consecutive_failures, 0);
    assert!(telemetry.last_fault.is_some());
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/hover",
        MethodExpectation::new(2, 1, 0, 1, 1),
    );

    Ok(())
}

/// An unknown raw method must surface as an LSP response error and be counted
/// as such, with no retry and no restart.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial]
async fn response_error_requests_are_telemetered() -> Result<(), Box<dyn Error>> {
    let fixture = make_fixture()?;
    let config = make_engine_config(&fixture, vec!["--mode".into(), "stable".into()])?;
    let engine = Engine::new(config);

    let invalid = engine
        .raw_lsp_request("textDocument/notReal", json!({}))
        .await;
    match invalid {
        Err(EngineError::LspResponse { .. }) => {}
        other => return Err(format!("expected LSP response error, got {other:?}").into()),
    }

    let telemetry = engine.telemetry_snapshot().await;
    assert_eq!(telemetry.totals.request_count, 1);
    assert_eq!(telemetry.totals.success_count, 0);
    assert_eq!(telemetry.totals.response_error_count, 1);
    assert_eq!(telemetry.totals.transport_fault_count, 0);
    assert_eq!(telemetry.totals.retry_count, 0);
    assert_eq!(telemetry.restart_count, 0);
    assert_method_counts(
        telemetry.methods.as_slice(),
        "textDocument/notReal",
        MethodExpectation::new(1, 0, 1, 0, 0),
    );

    Ok(())
}

/// Expected per-method telemetry counters, bundled so call sites stay terse.
#[derive(Debug, Clone, Copy)]
struct MethodExpectation {
    request_count: u64,
    success_count: u64,
    response_error_count: u64,
    transport_fault_count: u64,
    retry_count: u64,
}

impl MethodExpectation {
    const fn new(
        request_count: u64,
        success_count: u64,
        response_error_count: u64,
        transport_fault_count: u64,
        retry_count: u64,
    ) -> Self {
        Self {
            request_count,
            success_count,
            response_error_count,
            transport_fault_count,
            retry_count,
        }
    }
}

/// Assert that the telemetry entry for `method` matches `expected` exactly.
/// Uses assert! + early return instead of unwrap so a missing entry produces
/// a named failure rather than a panic message without context.
fn assert_method_counts(
    methods: &[ra_mcp_engine::MethodTelemetrySnapshot],
    method: &str,
    expected: MethodExpectation,
) {
    let maybe_entry = methods.iter().find(|entry| entry.method == method);
    assert!(
        maybe_entry.is_some(),
        "expected telemetry entry for method `{method}`",
    );
    let entry = if let Some(value) = maybe_entry {
        value
    } else {
        return;
    };
    assert_eq!(entry.request_count, expected.request_count);
    assert_eq!(entry.success_count, expected.success_count);
    assert_eq!(entry.response_error_count, expected.response_error_count);
    assert_eq!(entry.transport_fault_count, expected.transport_fault_count);
    assert_eq!(entry.retry_count, expected.retry_count);
}

/// Temporary cargo-project fixture; the directory is removed on drop.
struct Fixture {
    temp_dir: TempDir,
}

impl Fixture {
    fn path(&self) -> &std::path::Path {
        self.temp_dir.path()
    }

    /// Validated path to the fixture's `src/lib.rs`.
    fn source_file_path(&self) -> Result<SourceFilePath, Box<dyn Error>> {
        let path = self.path().join("src").join("lib.rs");
        SourceFilePath::try_new(path).map_err(|error| error.to_string().into())
    }

    /// Position 1:1 in the fixture source file.
    fn position(&self) -> Result<SourcePosition, Box<dyn Error>> {
        let line = OneIndexedLine::try_new(1).map_err(|error| error.to_string())?;
        let column = OneIndexedColumn::try_new(1).map_err(|error| error.to_string())?;
        Ok(SourcePosition::new(
            self.source_file_path()?,
            SourcePoint::new(line, column),
        ))
    }
}

/// Create a minimal one-crate workspace on disk for the fake server to load.
fn make_fixture() -> Result<Fixture, Box<dyn Error>> {
    let temp_dir = tempfile::tempdir()?;
    let src_dir = temp_dir.path().join("src");
    fs::create_dir_all(&src_dir)?;
    fs::write(
        temp_dir.path().join("Cargo.toml"),
        "[package]\nname = \"fixture\"\nversion = \"0.0.0\"\nedition = \"2024\"\n",
    )?;
    fs::write(src_dir.join("lib.rs"), "pub fn touch() -> i32 { 1 }\n")?;
    Ok(Fixture { temp_dir })
}

/// Engine config pointing at the fake rust-analyzer with tight timeouts and
/// a fast (5ms..20ms) backoff so recovery tests finish quickly.
fn make_engine_config(
    fixture: &Fixture,
    args: Vec<String>,
) -> Result<EngineConfig, Box<dyn Error>> {
    let workspace_root =
        WorkspaceRoot::try_new(fixture.path().to_path_buf()).map_err(|error| error.to_string())?;
    let binary = fake_rust_analyzer_binary()?;
    let backoff = BackoffPolicy::try_new(Duration::from_millis(5), Duration::from_millis(20))
        .map_err(|error| error.to_string())?;
    EngineConfig::try_new(
        workspace_root,
        binary,
        args,
        Vec::new(),
        Duration::from_secs(2),
        Duration::from_secs(2),
        backoff,
    )
    .map_err(|error| error.to_string().into())
}

/// Locate the fake-rust-analyzer test binary: prefer the CARGO_BIN_EXE_* env
/// vars cargo sets for bin targets (both hyphen and underscore spellings),
/// then fall back to walking up from the test executable to target/debug.
fn fake_rust_analyzer_binary() -> Result<PathBuf, Box<dyn Error>> {
    if let Ok(path) = std::env::var("CARGO_BIN_EXE_fake-rust-analyzer") {
        return Ok(PathBuf::from(path));
    }
    if let Ok(path) = std::env::var("CARGO_BIN_EXE_fake_rust_analyzer") {
        return Ok(PathBuf::from(path));
    }
    let current = std::env::current_exe()?;
    let deps_dir = current
        .parent()
        .ok_or_else(|| "failed to resolve test binary parent".to_owned())?;
    let debug_dir = deps_dir
        .parent()
        .ok_or_else(|| "failed to resolve target debug directory".to_owned())?;
    Ok(debug_dir.join("fake-rust-analyzer"))
}