diff options
| author | main <main@swarm.moe> | 2026-03-19 15:49:41 -0400 |
|---|---|---|
| committer | main <main@swarm.moe> | 2026-03-19 15:49:41 -0400 |
| commit | fa1bd32800b65aab31ea732dd240261b4047522c (patch) | |
| tree | 2fd08af6f36b8beb3c7c941990becc1a0a091d62 /crates/ra-mcp-engine/src/supervisor.rs | |
| download | adequate-rust-mcp-1.0.0.zip | |
Release adequate-rust-mcp 1.0.0v1.0.0
Diffstat (limited to 'crates/ra-mcp-engine/src/supervisor.rs')
| -rw-r--r-- | crates/ra-mcp-engine/src/supervisor.rs | 1257 |
1 files changed, 1257 insertions, 0 deletions
diff --git a/crates/ra-mcp-engine/src/supervisor.rs b/crates/ra-mcp-engine/src/supervisor.rs new file mode 100644 index 0000000..f0c7ea6 --- /dev/null +++ b/crates/ra-mcp-engine/src/supervisor.rs @@ -0,0 +1,1257 @@ +use crate::{ + config::EngineConfig, + error::{EngineError, EngineResult}, + lsp_transport::{WorkerHandle, WorkerRequestError, spawn_worker}, +}; +use lsp_types::{ + DiagnosticSeverity, GotoDefinitionResponse, Hover, HoverContents, Location, LocationLink, + MarkedString, Position, Range, Uri, WorkspaceEdit, +}; +use ra_mcp_domain::{ + fault::{Fault, RecoveryDirective}, + lifecycle::{DynamicLifecycle, LifecycleSnapshot}, + types::{ + InvariantViolation, OneIndexedColumn, OneIndexedLine, SourceFilePath, SourceLocation, + SourcePoint, SourcePosition, SourceRange, + }, +}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde_json::Value; +use std::{ + cmp::min, + collections::HashMap, + fs, + sync::Arc, + time::{Duration, Instant, SystemTime}, +}; +use tokio::{sync::Mutex, time::sleep}; +use tracing::{debug, warn}; +use url::Url; + +/// Hover response payload. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct HoverPayload { + /// Rendered markdown/text content, if available. + pub rendered: Option<String>, + /// Symbol range, if rust-analyzer provided one. + pub range: Option<SourceRange>, +} + +/// Diagnostic severity level. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiagnosticLevel { + /// Error severity. + Error, + /// Warning severity. + Warning, + /// Informational severity. + Information, + /// Hint severity. + Hint, +} + +/// One diagnostic record. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DiagnosticEntry { + /// Affected range. + pub range: SourceRange, + /// Severity. + pub level: DiagnosticLevel, + /// Optional diagnostic code. + pub code: Option<String>, + /// User-facing diagnostic message. + pub message: String, +} + +/// Diagnostics report for a single file. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DiagnosticsReport { + /// Entries returned by rust-analyzer. + pub diagnostics: Vec<DiagnosticEntry>, +} + +/// Summary of rename operation impact. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RenameReport { + /// Number of files touched by the edit. + pub files_touched: u64, + /// Number of text edits in total. + pub edits_applied: u64, +} + +/// Aggregate runtime telemetry snapshot for engine behavior. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TelemetrySnapshot { + /// Process uptime in milliseconds. + pub uptime_ms: u64, + /// Current lifecycle snapshot. + pub lifecycle: LifecycleSnapshot, + /// Number of consecutive failures currently tracked by supervisor. + pub consecutive_failures: u32, + /// Number of worker restarts performed. + pub restart_count: u64, + /// Global counters across all requests. + pub totals: TelemetryTotals, + /// Per-method counters and latency aggregates. + pub methods: Vec<MethodTelemetrySnapshot>, + /// Last fault that triggered worker restart, if any. + pub last_fault: Option<Fault>, +} + +/// Total request/fault counters. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TelemetryTotals { + /// Total request attempts issued to rust-analyzer. + pub request_count: u64, + /// Successful request attempts. + pub success_count: u64, + /// LSP response error attempts. + pub response_error_count: u64, + /// Transport/protocol fault attempts. + pub transport_fault_count: u64, + /// Retry attempts performed. + pub retry_count: u64, +} + +/// Per-method telemetry aggregate. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct MethodTelemetrySnapshot { + /// LSP method name. + pub method: String, + /// Total request attempts for this method. + pub request_count: u64, + /// Successful attempts. + pub success_count: u64, + /// LSP response error attempts. + pub response_error_count: u64, + /// Transport/protocol fault attempts. + pub transport_fault_count: u64, + /// Retry attempts for this method. + pub retry_count: u64, + /// Last observed attempt latency in milliseconds. + pub last_latency_ms: Option<u64>, + /// Maximum observed attempt latency in milliseconds. + pub max_latency_ms: u64, + /// Average attempt latency in milliseconds. + pub avg_latency_ms: u64, + /// Last error detail for this method, if any. + pub last_error: Option<String>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RequestMethod { + Hover, + Definition, + References, + Rename, + DocumentDiagnostic, + Raw(&'static str), +} + +impl RequestMethod { + const fn as_lsp_method(self) -> &'static str { + match self { + Self::Hover => "textDocument/hover", + Self::Definition => "textDocument/definition", + Self::References => "textDocument/references", + Self::Rename => "textDocument/rename", + Self::DocumentDiagnostic => "textDocument/diagnostic", + Self::Raw(method) => method, + } + } + + fn retry_delay(self, payload: &crate::lsp_transport::RpcErrorPayload) -> Option<Duration> { + if self.supports_transient_response_retry() + && is_transient_response_error(payload.code, payload.message.as_str()) + { + return Some(self.transient_response_retry_delay()); + } + let retryable_method = matches!( + self.as_lsp_method(), + "textDocument/rename" + | "textDocument/prepareRename" + | "textDocument/definition" + | "textDocument/references" + ); + if !retryable_method + || payload.code != -32602 + || !payload.message.contains("No references found at position") + { + return None; + } + match self.as_lsp_method() { + "textDocument/rename" | "textDocument/prepareRename" => { + Some(Duration::from_millis(1500)) + } + _ => Some(Duration::from_millis(250)), + } + } + + const fn supports_transient_response_retry(self) -> bool { + matches!( + self, + Self::Hover + | Self::Definition + | Self::References + | Self::Rename + | Self::DocumentDiagnostic + ) + } + + fn transient_response_retry_delay(self) -> Duration { + match self { + Self::DocumentDiagnostic => Duration::from_millis(250), + Self::Rename => Duration::from_millis(350), + Self::Hover | Self::Definition | Self::References => Duration::from_millis(150), + Self::Raw(_) => Duration::from_millis(0), + } + } +} + +fn is_transient_response_error(code: i64, message: &str) -> bool { + let normalized = message.to_ascii_lowercase(); + code == -32801 + || code == -32802 + || normalized.contains("content modified") + || normalized.contains("document changed") + || normalized.contains("server cancelled") + || normalized.contains("request cancelled") + || normalized.contains("request canceled") +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentIdentifierWire { + uri: String, +} + +#[derive(Debug, Clone, Copy, Serialize)] +struct PositionWire { + line: u32, + character: u32, +} + +impl From<SourcePoint> for PositionWire { + fn from(value: SourcePoint) -> Self { + Self { + line: value.line().to_zero_indexed(), + character: value.column().to_zero_indexed(), + } + } +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentPositionParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, +} + +#[derive(Debug, Clone, Serialize)] +struct ReferencesContextWire { + #[serde(rename = "includeDeclaration")] + include_declaration: bool, +} + +#[derive(Debug, Clone, Serialize)] +struct ReferencesParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, + context: ReferencesContextWire, +} + +#[derive(Debug, Clone, Serialize)] +struct RenameParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, + position: PositionWire, + #[serde(rename = "newName")] + new_name: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DocumentDiagnosticParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentIdentifierWire, +} + +#[derive(Debug, Clone, Serialize)] +struct VersionedTextDocumentIdentifierWire { + uri: String, + version: i32, +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentContentChangeEventWire { + text: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DidChangeTextDocumentParamsWire { + #[serde(rename = "textDocument")] + text_document: VersionedTextDocumentIdentifierWire, + #[serde(rename = "contentChanges")] + content_changes: Vec<TextDocumentContentChangeEventWire>, +} + +#[derive(Debug, Clone, Serialize)] +struct TextDocumentItemWire { + uri: String, + #[serde(rename = "languageId")] + language_id: &'static str, + version: i32, + text: String, +} + +#[derive(Debug, Clone, Serialize)] +struct DidOpenTextDocumentParamsWire { + #[serde(rename = "textDocument")] + text_document: TextDocumentItemWire, +} + +/// Resilient engine façade. +#[derive(Clone)] +pub struct Engine { + supervisor: Arc<Mutex<Supervisor>>, +} + +struct Supervisor { + config: EngineConfig, + lifecycle: DynamicLifecycle, + worker: Option<WorkerHandle>, + consecutive_failures: u32, + open_documents: HashMap<SourceFilePath, OpenDocumentState>, + telemetry: TelemetryState, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OpenDocumentState { + version: i32, + fingerprint: SourceFileFingerprint, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SourceFileFingerprint { + byte_len: u64, + modified_nanos_since_epoch: u128, +} + +#[derive(Debug)] +struct TelemetryState { + started_at: Instant, + totals: TelemetryTotalsState, + methods: HashMap<&'static str, MethodTelemetryState>, + restart_count: u64, + last_fault: Option<Fault>, +} + +#[derive(Debug, Default)] +struct TelemetryTotalsState { + request_count: u64, + success_count: u64, + response_error_count: u64, + transport_fault_count: u64, + retry_count: u64, +} + +#[derive(Debug, Default)] +struct MethodTelemetryState { + request_count: u64, + success_count: u64, + response_error_count: u64, + transport_fault_count: u64, + retry_count: u64, + total_latency_ms: u128, + last_latency_ms: Option<u64>, + max_latency_ms: u64, + last_error: Option<String>, +} + +impl Engine { + /// Creates a new engine. + #[must_use] + pub fn new(config: EngineConfig) -> Self { + Self { + supervisor: Arc::new(Mutex::new(Supervisor::new(config))), + } + } + + /// Returns current lifecycle snapshot. + pub async fn lifecycle_snapshot(&self) -> LifecycleSnapshot { + let supervisor = self.supervisor.lock().await; + supervisor.snapshot() + } + + /// Returns aggregate request/fault telemetry snapshot. + pub async fn telemetry_snapshot(&self) -> TelemetrySnapshot { + let supervisor = self.supervisor.lock().await; + supervisor.telemetry_snapshot() + } + + /// Executes hover request. + pub async fn hover(&self, position: SourcePosition) -> EngineResult<HoverPayload> { + let document_hint = Some(position.file_path().clone()); + let request = text_document_position_params(&position)?; + let hover = self + .issue_typed_request::<_, Option<Hover>>(RequestMethod::Hover, &request, document_hint) + .await?; + let payload = hover + .map(|hover| -> Result<HoverPayload, EngineError> { + let range = hover + .range + .map(|range| range_to_source_range(position.file_path(), range)) + .transpose()?; + Ok(HoverPayload { + rendered: Some(render_hover_contents(hover.contents)), + range, + }) + }) + .transpose()? + .unwrap_or(HoverPayload { + rendered: None, + range: None, + }); + Ok(payload) + } + + /// Executes definition request. + pub async fn definition(&self, position: SourcePosition) -> EngineResult<Vec<SourceLocation>> { + let document_hint = Some(position.file_path().clone()); + let request = text_document_position_params(&position)?; + let parsed = self + .issue_typed_request::<_, Option<GotoDefinitionResponse>>( + RequestMethod::Definition, + &request, + document_hint, + ) + .await?; + let locations = match parsed { + None => Vec::new(), + Some(GotoDefinitionResponse::Scalar(location)) => { + vec![source_location_from_lsp_location(location)?] + } + Some(GotoDefinitionResponse::Array(locations)) => locations + .into_iter() + .map(source_location_from_lsp_location) + .collect::<Result<Vec<_>, _>>()?, + Some(GotoDefinitionResponse::Link(links)) => links + .into_iter() + .map(source_location_from_lsp_link) + .collect::<Result<Vec<_>, _>>()?, + }; + Ok(locations) + } + + /// Executes references request. + pub async fn references(&self, position: SourcePosition) -> EngineResult<Vec<SourceLocation>> { + let request = ReferencesParamsWire { + text_document: text_document_identifier(position.file_path())?, + position: PositionWire::from(position.point()), + context: ReferencesContextWire { + include_declaration: true, + }, + }; + let parsed = self + .issue_typed_request::<_, Option<Vec<Location>>>( + RequestMethod::References, + &request, + Some(position.file_path().clone()), + ) + .await?; + parsed + .unwrap_or_default() + .into_iter() + .map(source_location_from_lsp_location) + .collect::<Result<Vec<_>, _>>() + } + + /// Executes rename request. + pub async fn rename_symbol( + &self, + position: SourcePosition, + new_name: String, + ) -> EngineResult<RenameReport> { + let request = RenameParamsWire { + text_document: text_document_identifier(position.file_path())?, + position: PositionWire::from(position.point()), + new_name, + }; + let edit = self + .issue_typed_request::<_, WorkspaceEdit>( + RequestMethod::Rename, + &request, + Some(position.file_path().clone()), + ) + .await?; + Ok(summarize_workspace_edit(edit)) + } + + /// Executes document diagnostics request. + pub async fn diagnostics(&self, file_path: SourceFilePath) -> EngineResult<DiagnosticsReport> { + let request = DocumentDiagnosticParamsWire { + text_document: text_document_identifier(&file_path)?, + }; + let response = self + .issue_request( + RequestMethod::DocumentDiagnostic, + &request, + Some(file_path.clone()), + ) + .await?; + parse_diagnostics_report(&file_path, response) + } + + /// Executes an arbitrary typed LSP request and returns raw JSON payload. + pub async fn raw_lsp_request( + &self, + method: &'static str, + params: Value, + ) -> EngineResult<Value> { + let document_hint = source_file_path_hint_from_request_params(¶ms)?; + self.issue_request(RequestMethod::Raw(method), ¶ms, document_hint) + .await + } + + async fn issue_typed_request<P, R>( + &self, + method: RequestMethod, + params: &P, + document_hint: Option<SourceFilePath>, + ) -> EngineResult<R> + where + P: Serialize, + R: DeserializeOwned, + { + let response = self.issue_request(method, params, document_hint).await?; + serde_json::from_value::<R>(response).map_err(|error| EngineError::InvalidPayload { + method: method.as_lsp_method(), + message: error.to_string(), + }) + } + + async fn issue_request<P>( + &self, + method: RequestMethod, + params: &P, + document_hint: Option<SourceFilePath>, + ) -> EngineResult<Value> + where + P: Serialize, + { + let max_attempts = 2_u8; + let mut attempt = 0_u8; + while attempt < max_attempts { + attempt = attempt.saturating_add(1); + let (worker, request_timeout) = { + let mut supervisor = self.supervisor.lock().await; + let worker = supervisor.ensure_worker().await?; + if let Some(file_path) = document_hint.as_ref() { + supervisor.synchronize_document(&worker, file_path).await?; + } + (worker, supervisor.request_timeout()) + }; + + let attempt_started_at = Instant::now(); + let result = worker + .send_request(method.as_lsp_method(), params, request_timeout) + .await; + let latency = attempt_started_at.elapsed(); + match result { + Ok(value) => { + let mut supervisor = self.supervisor.lock().await; + supervisor.record_success(method.as_lsp_method(), latency); + return Ok(value); + } + Err(WorkerRequestError::Response(payload)) => { + let retry_delay = (attempt < max_attempts) + .then(|| method.retry_delay(&payload)) + .flatten(); + let should_retry = retry_delay.is_some(); + { + let mut supervisor = self.supervisor.lock().await; + supervisor.record_response_error( + method.as_lsp_method(), + latency, + payload.code, + format_lsp_response_error_detail(&payload), + should_retry, + ); + } + + if let Some(retry_delay) = retry_delay { + debug!( + attempt, + method = method.as_lsp_method(), + code = payload.code, + delay_ms = retry_delay.as_millis(), + "retrying request after transient lsp response error" + ); + sleep(retry_delay).await; + continue; + } + return Err(EngineError::from(payload)); + } + Err(WorkerRequestError::Fault(fault)) => { + let directive = fault.directive(); + let will_retry = matches!( + directive, + RecoveryDirective::RetryInPlace | RecoveryDirective::RestartAndReplay + ) && attempt < max_attempts; + { + let mut supervisor = self.supervisor.lock().await; + supervisor.record_transport_fault( + method.as_lsp_method(), + latency, + fault.detail.message.clone(), + will_retry, + ); + } + + match directive { + RecoveryDirective::RetryInPlace => { + debug!( + attempt, + method = method.as_lsp_method(), + "retrying request in-place after fault" + ); + if attempt >= max_attempts { + return Err(EngineError::Fault(fault)); + } + } + RecoveryDirective::RestartAndReplay => { + let mut supervisor = self.supervisor.lock().await; + supervisor.record_fault(fault.clone()).await?; + if attempt >= max_attempts { + return Err(EngineError::Fault(fault)); + } + debug!( + attempt, + method = method.as_lsp_method(), + "restarting worker and replaying request" + ); + } + RecoveryDirective::AbortRequest => { + let mut supervisor = self.supervisor.lock().await; + supervisor.record_fault(fault.clone()).await?; + return Err(EngineError::Fault(fault)); + } + } + } + } + } + Err(EngineError::Fault(Fault::new( + self.lifecycle_generation().await, + ra_mcp_domain::fault::FaultClass::Resource, + ra_mcp_domain::fault::FaultCode::RequestTimedOut, + ra_mcp_domain::fault::FaultDetail::new(format!( + "exhausted retries for method {}", + method.as_lsp_method() + )), + ))) + } + + async fn lifecycle_generation(&self) -> ra_mcp_domain::types::Generation { + let supervisor = self.supervisor.lock().await; + supervisor.generation() + } +} + +impl TelemetryState { + fn new() -> Self { + Self { + started_at: Instant::now(), + totals: TelemetryTotalsState::default(), + methods: HashMap::new(), + restart_count: 0, + last_fault: None, + } + } + + fn record_success(&mut self, method: &'static str, latency: Duration) { + self.totals.request_count = self.totals.request_count.saturating_add(1); + self.totals.success_count = self.totals.success_count.saturating_add(1); + let entry = self.methods.entry(method).or_default(); + entry.request_count = entry.request_count.saturating_add(1); + entry.success_count = entry.success_count.saturating_add(1); + entry.record_latency(latency); + entry.last_error = None; + } + + fn record_response_error( + &mut self, + method: &'static str, + latency: Duration, + detail: String, + retry_performed: bool, + ) { + self.totals.request_count = self.totals.request_count.saturating_add(1); + self.totals.response_error_count = self.totals.response_error_count.saturating_add(1); + if retry_performed { + self.totals.retry_count = self.totals.retry_count.saturating_add(1); + } + + let entry = self.methods.entry(method).or_default(); + entry.request_count = entry.request_count.saturating_add(1); + entry.response_error_count = entry.response_error_count.saturating_add(1); + if retry_performed { + entry.retry_count = entry.retry_count.saturating_add(1); + } + entry.record_latency(latency); + entry.last_error = Some(detail); + } + + fn record_transport_fault( + &mut self, + method: &'static str, + latency: Duration, + detail: String, + retry_performed: bool, + ) { + self.totals.request_count = self.totals.request_count.saturating_add(1); + self.totals.transport_fault_count = self.totals.transport_fault_count.saturating_add(1); + if retry_performed { + self.totals.retry_count = self.totals.retry_count.saturating_add(1); + } + + let entry = self.methods.entry(method).or_default(); + entry.request_count = entry.request_count.saturating_add(1); + entry.transport_fault_count = entry.transport_fault_count.saturating_add(1); + if retry_performed { + entry.retry_count = entry.retry_count.saturating_add(1); + } + entry.record_latency(latency); + entry.last_error = Some(detail); + } + + fn record_restart(&mut self, fault: Fault) { + self.restart_count = self.restart_count.saturating_add(1); + self.last_fault = Some(fault); + } + + fn snapshot( + &self, + lifecycle: LifecycleSnapshot, + consecutive_failures: u32, + ) -> TelemetrySnapshot { + let mut methods = self + .methods + .iter() + .map(|(method, entry)| MethodTelemetrySnapshot { + method: (*method).to_owned(), + request_count: entry.request_count, + success_count: entry.success_count, + response_error_count: entry.response_error_count, + transport_fault_count: entry.transport_fault_count, + retry_count: entry.retry_count, + last_latency_ms: entry.last_latency_ms, + max_latency_ms: entry.max_latency_ms, + avg_latency_ms: entry.average_latency_ms(), + last_error: entry.last_error.clone(), + }) + .collect::<Vec<_>>(); + methods.sort_by(|left, right| left.method.cmp(&right.method)); + + let uptime_ms = duration_millis_u64(self.started_at.elapsed()); + TelemetrySnapshot { + uptime_ms, + lifecycle, + consecutive_failures, + restart_count: self.restart_count, + totals: TelemetryTotals { + request_count: self.totals.request_count, + success_count: self.totals.success_count, + response_error_count: self.totals.response_error_count, + transport_fault_count: self.totals.transport_fault_count, + retry_count: self.totals.retry_count, + }, + methods, + last_fault: self.last_fault.clone(), + } + } +} + +impl MethodTelemetryState { + fn record_latency(&mut self, latency: Duration) { + let latency_ms = duration_millis_u64(latency); + self.last_latency_ms = Some(latency_ms); + self.max_latency_ms = self.max_latency_ms.max(latency_ms); + self.total_latency_ms = self.total_latency_ms.saturating_add(latency_ms as u128); + } + + fn average_latency_ms(&self) -> u64 { + if self.request_count == 0 { + return 0; + } + let avg = self.total_latency_ms / u128::from(self.request_count); + if avg > u128::from(u64::MAX) { + u64::MAX + } else { + avg as u64 + } + } +} + +fn duration_millis_u64(duration: Duration) -> u64 { + let millis = duration.as_millis(); + if millis > u128::from(u64::MAX) { + u64::MAX + } else { + millis as u64 + } +} + +impl Supervisor { + fn new(config: EngineConfig) -> Self { + Self { + config, + lifecycle: DynamicLifecycle::cold(), + worker: None, + consecutive_failures: 0, + open_documents: HashMap::new(), + telemetry: TelemetryState::new(), + } + } + + fn request_timeout(&self) -> Duration { + self.config.request_timeout + } + + async fn synchronize_document( + &mut self, + worker: &WorkerHandle, + file_path: &SourceFilePath, + ) -> EngineResult<()> { + let fingerprint = capture_source_file_fingerprint(file_path)?; + if let Some(existing) = self.open_documents.get_mut(file_path) { + if existing.fingerprint == fingerprint { + return Ok(()); + } + let text = fs::read_to_string(file_path.as_path())?; + let next_version = existing.version.saturating_add(1); + let params = DidChangeTextDocumentParamsWire { + text_document: VersionedTextDocumentIdentifierWire { + uri: file_uri_string_from_source_path(file_path)?, + version: next_version, + }, + content_changes: vec![TextDocumentContentChangeEventWire { text }], + }; + worker + .send_notification("textDocument/didChange", ¶ms) + .await + .map_err(EngineError::from)?; + existing.version = next_version; + existing.fingerprint = fingerprint; + return Ok(()); + } + + let text = fs::read_to_string(file_path.as_path())?; + let params = DidOpenTextDocumentParamsWire { + text_document: TextDocumentItemWire { + uri: file_uri_string_from_source_path(file_path)?, + language_id: "rust", + version: 1, + text, + }, + }; + worker + .send_notification("textDocument/didOpen", ¶ms) + .await + .map_err(EngineError::from)?; + let _previous = self.open_documents.insert( + file_path.clone(), + OpenDocumentState { + version: 1, + fingerprint, + }, + ); + Ok(()) + } + + fn snapshot(&self) -> LifecycleSnapshot { + self.lifecycle.snapshot() + } + + fn telemetry_snapshot(&self) -> TelemetrySnapshot { + let lifecycle = self.snapshot(); + self.telemetry + .snapshot(lifecycle, self.consecutive_failures) + } + + fn generation(&self) -> ra_mcp_domain::types::Generation { + let snapshot = self.snapshot(); + match snapshot { + LifecycleSnapshot::Cold { generation } + | LifecycleSnapshot::Starting { generation } + | LifecycleSnapshot::Ready { generation } + | LifecycleSnapshot::Recovering { generation, .. } => generation, + } + } + + async fn ensure_worker(&mut self) -> EngineResult<WorkerHandle> { + if let Some(worker) = self.worker.clone() { + if let Some(fault) = worker.terminal_fault() { + warn!( + generation = fault.generation.get(), + "worker marked terminal, recycling" + ); + self.record_fault(fault).await?; + } else { + return Ok(worker); + } + } + self.spawn_worker().await + } + + async fn spawn_worker(&mut self) -> EngineResult<WorkerHandle> { + self.lifecycle = self.lifecycle.clone().begin_startup()?; + let generation = self.generation(); + let started = spawn_worker(&self.config, generation).await; + match started { + Ok(worker) => { + self.lifecycle = self.lifecycle.clone().complete_startup()?; + self.worker = Some(worker.clone()); + self.consecutive_failures = 0; + self.open_documents.clear(); + Ok(worker) + } + Err(fault) => { + self.record_fault(fault.clone()).await?; + Err(EngineError::Fault(fault)) + } + } + } + + async fn record_fault(&mut self, fault: Fault) -> EngineResult<()> { + self.lifecycle = fracture_or_force_recovery(self.lifecycle.clone(), fault.clone())?; + self.consecutive_failures = self.consecutive_failures.saturating_add(1); + self.telemetry.record_restart(fault.clone()); + + if let Some(worker) = self.worker.take() { + worker.terminate().await; + } + self.open_documents.clear(); + + let delay = self.next_backoff_delay(); + debug!( + failures = self.consecutive_failures, + delay_ms = delay.as_millis(), + "applying restart backoff delay" + ); + sleep(delay).await; + Ok(()) + } + + fn record_success(&mut self, method: &'static str, latency: Duration) { + self.consecutive_failures = 0; + self.telemetry.record_success(method, latency); + } + + fn record_response_error( + &mut self, + method: &'static str, + latency: Duration, + code: i64, + message: String, + retry_performed: bool, + ) { + let detail = format!("code={code} message={message}"); + self.telemetry + .record_response_error(method, latency, detail, retry_performed); + } + + fn record_transport_fault( + &mut self, + method: &'static str, + latency: Duration, + detail: String, + retry_performed: bool, + ) { + self.telemetry + .record_transport_fault(method, latency, detail, retry_performed); + } + + fn next_backoff_delay(&self) -> Duration { + let exponent = self.consecutive_failures.saturating_sub(1); + let multiplier = if exponent >= 31 { + u32::MAX + } else { + 1_u32 << exponent + }; + let scaled = self.config.backoff_policy.floor.saturating_mul(multiplier); + min(scaled, self.config.backoff_policy.ceiling) + } +} + +fn fracture_or_force_recovery( + lifecycle: DynamicLifecycle, + fault: Fault, +) -> EngineResult<DynamicLifecycle> { + match lifecycle.clone().fracture(fault.clone()) { + Ok(next) => Ok(next), + Err(_error) => { + let started = lifecycle.begin_startup()?; + started.fracture(fault).map_err(EngineError::from) + } + } +} + +fn text_document_identifier( + file_path: &SourceFilePath, +) -> EngineResult<TextDocumentIdentifierWire> { + Ok(TextDocumentIdentifierWire { + uri: file_uri_string_from_source_path(file_path)?, + }) +} + +fn text_document_position_params( + position: &SourcePosition, +) -> EngineResult<TextDocumentPositionParamsWire> { + Ok(TextDocumentPositionParamsWire { + text_document: text_document_identifier(position.file_path())?, + position: PositionWire::from(position.point()), + }) +} + +fn format_lsp_response_error_detail(payload: &crate::lsp_transport::RpcErrorPayload) -> String { + let crate::lsp_transport::RpcErrorPayload { + code, + message, + data, + } = payload; + match data { + Some(data) => format!("code={code} message={message} data={data}"), + None => format!("code={code} message={message}"), + } +} + +fn file_uri_string_from_source_path(file_path: &SourceFilePath) -> EngineResult<String> { + let file_url = + Url::from_file_path(file_path.as_path()).map_err(|()| EngineError::InvalidFileUrl)?; + Ok(file_url.to_string()) +} + +fn source_file_path_hint_from_request_params( + params: &Value, +) -> EngineResult<Option<SourceFilePath>> { + let maybe_uri = params + .get("textDocument") + .and_then(Value::as_object) + .and_then(|document| document.get("uri")) + .and_then(Value::as_str); + let Some(uri) = maybe_uri else { + return Ok(None); + }; + let file_path = source_file_path_from_file_uri_str(uri)?; + Ok(Some(file_path)) +} + +fn source_file_path_from_file_uri_str(uri: &str) -> EngineResult<SourceFilePath> { + let file_url = Url::parse(uri).map_err(|_error| EngineError::InvalidFileUrl)?; + let file_path = file_url + .to_file_path() + .map_err(|()| EngineError::InvalidFileUrl)?; + SourceFilePath::try_new(file_path).map_err(EngineError::from) +} + +fn capture_source_file_fingerprint( + file_path: &SourceFilePath, +) -> EngineResult<SourceFileFingerprint> { + let metadata = fs::metadata(file_path.as_path())?; + let modified = metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH); + let modified_nanos_since_epoch = modified + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or(Duration::ZERO) + .as_nanos(); + Ok(SourceFileFingerprint { + byte_len: metadata.len(), + modified_nanos_since_epoch, + }) +} + +fn source_location_from_lsp_link(link: LocationLink) -> EngineResult<SourceLocation> { + let uri = link.target_uri; + let range = link.target_selection_range; + source_location_from_uri_and_position(uri, range.start) +} + +fn source_location_from_lsp_location(location: Location) -> EngineResult<SourceLocation> { + source_location_from_uri_and_position(location.uri, location.range.start) +} + +fn source_location_from_uri_and_position( + uri: Uri, + position: Position, +) -> EngineResult<SourceLocation> { + let file_url = Url::parse(uri.as_str()).map_err(|_error| EngineError::InvalidFileUrl)?; + let path = file_url + .to_file_path() + .map_err(|()| EngineError::InvalidFileUrl)?; + let file_path = SourceFilePath::try_new(path)?; + let point = SourcePoint::new( + OneIndexedLine::try_new(u64::from(position.line).saturating_add(1))?, + OneIndexedColumn::try_new(u64::from(position.character).saturating_add(1))?, + ); + Ok(SourceLocation::new(file_path, point)) +} + +fn range_to_source_range( + file_path: &SourceFilePath, + range: Range, +) -> Result<SourceRange, InvariantViolation> { + let start = SourcePoint::new( + OneIndexedLine::try_new(u64::from(range.start.line).saturating_add(1))?, + OneIndexedColumn::try_new(u64::from(range.start.character).saturating_add(1))?, + ); + let end = SourcePoint::new( + OneIndexedLine::try_new(u64::from(range.end.line).saturating_add(1))?, + OneIndexedColumn::try_new(u64::from(range.end.character).saturating_add(1))?, + ); + SourceRange::try_new(file_path.clone(), start, end) +} + +fn render_hover_contents(contents: HoverContents) -> String { + match contents { + HoverContents::Scalar(marked_string) => marked_string_to_string(marked_string), + HoverContents::Array(items) => items + .into_iter() + .map(marked_string_to_string) + .collect::<Vec<_>>() + .join("\n"), + HoverContents::Markup(markup) => markup.value, + } +} + +fn marked_string_to_string(marked: MarkedString) -> String { + match marked { + MarkedString::String(value) => value, + MarkedString::LanguageString(language_string) => { + format!( + "```{}\n{}\n```", + language_string.language, language_string.value + ) + } + } +} + +fn summarize_workspace_edit(edit: WorkspaceEdit) -> RenameReport { + let mut touched = HashMap::<String, u64>::new(); + let mut edits_applied = 0_u64; + + if let Some(changes) = edit.changes { + for (uri, edits) in changes { + let edit_count = u64::try_from(edits.len()).unwrap_or(u64::MAX); + let _previous = touched.insert(uri.as_str().to_owned(), edit_count); + edits_applied = edits_applied.saturating_add(edit_count); + } + } + + if let Some(document_changes) = edit.document_changes { + match document_changes { + lsp_types::DocumentChanges::Edits(edits) => { + for document_edit in edits { + let uri = document_edit.text_document.uri; + let edit_count = u64::try_from(document_edit.edits.len()).unwrap_or(u64::MAX); + let _entry = touched + .entry(uri.as_str().to_owned()) + .and_modify(|count| *count = count.saturating_add(edit_count)) + .or_insert(edit_count); + edits_applied = edits_applied.saturating_add(edit_count); + } + } + lsp_types::DocumentChanges::Operations(operations) => { + edits_applied = edits_applied + .saturating_add(u64::try_from(operations.len()).unwrap_or(u64::MAX)); + for operation in operations { + match operation { + lsp_types::DocumentChangeOperation::Op(operation) => match operation { + lsp_types::ResourceOp::Create(create) => { + let _entry = + touched.entry(create.uri.as_str().to_owned()).or_insert(0); + } + lsp_types::ResourceOp::Rename(rename) => { + let _entry = touched + .entry(rename.new_uri.as_str().to_owned()) + .or_insert(0); + } + lsp_types::ResourceOp::Delete(delete) => { + let _entry = + touched.entry(delete.uri.as_str().to_owned()).or_insert(0); + } + }, + lsp_types::DocumentChangeOperation::Edit(edit) => { + let edit_count = u64::try_from(edit.edits.len()).unwrap_or(u64::MAX); + let _entry = touched + .entry(edit.text_document.uri.as_str().to_owned()) + .and_modify(|count| *count = count.saturating_add(edit_count)) + .or_insert(edit_count); + } + } + } + } + } + } + + RenameReport { + files_touched: u64::try_from(touched.len()).unwrap_or(u64::MAX), + edits_applied, + } +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "kind", rename_all = "lowercase")] +enum DiagnosticReportWire { + Full { items: Vec<DiagnosticWire> }, + Unchanged {}, +} + +#[derive(Debug, Deserialize)] +struct DiagnosticWire { + range: Range, + severity: Option<DiagnosticSeverity>, + code: Option<Value>, + message: String, +} + +fn parse_diagnostics_report( + file_path: &SourceFilePath, + value: Value, +) -> EngineResult<DiagnosticsReport> { + let parsed = serde_json::from_value::<DiagnosticReportWire>(value).map_err(|error| { + EngineError::InvalidPayload { + method: "textDocument/diagnostic", + message: error.to_string(), + } + })?; + match parsed { + DiagnosticReportWire::Unchanged {} => Ok(DiagnosticsReport { + diagnostics: Vec::new(), + }), + DiagnosticReportWire::Full { items } => { + let diagnostics = items + .into_iter() + .map(|item| { + let range = range_to_source_range(file_path, item.range)?; + let level = match item.severity.unwrap_or(DiagnosticSeverity::INFORMATION) { + DiagnosticSeverity::ERROR => DiagnosticLevel::Error, + DiagnosticSeverity::WARNING => DiagnosticLevel::Warning, + DiagnosticSeverity::INFORMATION => DiagnosticLevel::Information, + DiagnosticSeverity::HINT => DiagnosticLevel::Hint, + _ => DiagnosticLevel::Information, + }; + let code = item.code.map(|value| match value { + Value::String(message) => message, + Value::Number(number) => number.to_string(), + other => other.to_string(), + }); + Ok(DiagnosticEntry { + range, + level, + code, + message: item.message, + }) + }) + .collect::<Result<Vec<_>, InvariantViolation>>()?; + Ok(DiagnosticsReport { diagnostics }) + } + } +} |