diff options
Diffstat (limited to 'crates/phone-opus/src/mcp/service.rs')
| -rw-r--r-- | crates/phone-opus/src/mcp/service.rs | 550 |
1 files changed, 550 insertions, 0 deletions
diff --git a/crates/phone-opus/src/mcp/service.rs b/crates/phone-opus/src/mcp/service.rs new file mode 100644 index 0000000..2472887 --- /dev/null +++ b/crates/phone-opus/src/mcp/service.rs @@ -0,0 +1,550 @@ +use std::collections::BTreeMap; +use std::io::{self, BufRead, Write}; +use std::path::{Path, PathBuf}; +use std::process::Command; + +use libmcp::{Generation, SurfaceKind}; +use serde::Deserialize; +use serde_json::{Value, json}; +use thiserror::Error; + +use crate::mcp::fault::{FaultRecord, FaultStage}; +use crate::mcp::output::{ + ToolOutput, fallback_detailed_tool_output, split_presentation, tool_success, +}; +use crate::mcp::protocol::{CLAUDE_BIN_ENV, CLAUDE_TOOLSET, EMPTY_MCP_CONFIG}; + +pub(crate) fn run_worker(generation: u64) -> Result<(), Box<dyn std::error::Error>> { + let generation = generation_from_wire(generation); + let stdin = io::stdin(); + let mut stdout = io::stdout().lock(); + let mut service = WorkerService::new(generation); + + for line in stdin.lock().lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + let request = serde_json::from_str::<crate::mcp::protocol::WorkerRequest>(&line)?; + let response = match request { + crate::mcp::protocol::WorkerRequest::Execute { id, operation } => { + let outcome = match service.execute(operation) { + Ok(result) => crate::mcp::protocol::WorkerOutcome::Success { result }, + Err(fault) => crate::mcp::protocol::WorkerOutcome::Fault { fault }, + }; + crate::mcp::protocol::WorkerResponse { id, outcome } + } + }; + serde_json::to_writer(&mut stdout, &response)?; + stdout.write_all(b"\n")?; + stdout.flush()?; + } + + Ok(()) +} + +struct WorkerService { + generation: Generation, +} + +impl WorkerService { + fn new(generation: Generation) -> Self { + Self { generation } + } + + fn execute( + &mut self, + operation: crate::mcp::protocol::WorkerOperation, + ) -> Result<Value, FaultRecord> { + match operation { + crate::mcp::protocol::WorkerOperation::CallTool { name, arguments } => { + self.call_tool(&name, arguments) + } + } + } + + fn call_tool(&mut self, name: &str, arguments: Value) -> Result<Value, FaultRecord> { + let operation = format!("tools/call:{name}"); + let (presentation, arguments) = + split_presentation(arguments, &operation, self.generation, FaultStage::Worker)?; + let output = match name { + "consult" => { + let args = deserialize::<ConsultArgs>(arguments, &operation, self.generation)?; + let request = ConsultRequest::parse(args) + .map_err(|error| invalid_consult_request(self.generation, &operation, error))?; + let response = invoke_claude(&request) + .map_err(|error| consult_fault(self.generation, &operation, error))?; + consult_output(&request, &response, self.generation, &operation)? + } + other => { + return Err(FaultRecord::invalid_input( + self.generation, + FaultStage::Worker, + &operation, + format!("unknown worker tool `{other}`"), + )); + } + }; + tool_success( + output, + presentation, + self.generation, + FaultStage::Worker, + &operation, + ) + } +} + +#[derive(Debug, Deserialize)] +struct ConsultArgs { + prompt: String, + cwd: Option<String>, + max_turns: Option<u64>, +} + +#[derive(Debug, Clone)] +struct ConsultRequest { + prompt: PromptText, + cwd: WorkingDirectory, + max_turns: Option<TurnLimit>, +} + +impl ConsultRequest { + fn parse(args: ConsultArgs) -> Result<Self, ConsultRequestError> { + Ok(Self { + prompt: PromptText::parse(args.prompt)?, + cwd: WorkingDirectory::resolve(args.cwd)?, + max_turns: args.max_turns.map(TurnLimit::parse).transpose()?, + }) + } +} + +#[derive(Debug, Clone)] +struct PromptText(String); + +impl PromptText { + fn parse(raw: String) -> Result<Self, ConsultRequestError> { + if raw.trim().is_empty() { + return Err(ConsultRequestError::EmptyPrompt); + } + Ok(Self(raw)) + } + + fn as_str(&self) -> &str { + self.0.as_str() + } +} + +#[derive(Debug, Clone)] +struct WorkingDirectory(PathBuf); + +impl WorkingDirectory { + fn resolve(raw: Option<String>) -> Result<Self, ConsultRequestError> { + let base = + std::env::current_dir().map_err(|source| ConsultRequestError::CurrentDir { source })?; + let requested = raw.map_or_else(|| base.clone(), PathBuf::from); + let candidate = if requested.is_absolute() { + requested + } else { + base.join(requested) + }; + let canonical = + candidate + .canonicalize() + .map_err(|source| ConsultRequestError::Canonicalize { + path: candidate.display().to_string(), + source, + })?; + if !canonical.is_dir() { + return Err(ConsultRequestError::NotDirectory( + canonical.display().to_string(), + )); + } + Ok(Self(canonical)) + } + + fn as_path(&self) -> &Path { + self.0.as_path() + } + + fn display(&self) -> String { + self.0.display().to_string() + } +} + +#[derive(Debug, Clone, Copy)] +struct TurnLimit(u64); + +impl TurnLimit { + fn parse(raw: u64) -> Result<Self, ConsultRequestError> { + if raw == 0 { + return Err(ConsultRequestError::InvalidTurnLimit); + } + Ok(Self(raw)) + } + + fn get(self) -> u64 { + self.0 + } +} + +#[derive(Debug, Error)] +enum ConsultRequestError { + #[error("prompt must not be empty")] + EmptyPrompt, + #[error("failed to resolve the current working directory: {source}")] + CurrentDir { source: io::Error }, + #[error("failed to resolve working directory `{path}`: {source}")] + Canonicalize { path: String, source: io::Error }, + #[error("working directory `{0}` is not a directory")] + NotDirectory(String), + #[error("max_turns must be greater than zero")] + InvalidTurnLimit, +} + +#[derive(Debug, Error)] +enum ConsultInvocationError { + #[error("failed to spawn Claude Code: {0}")] + Spawn(#[source] io::Error), + #[error("Claude Code returned non-JSON output: {0}")] + InvalidJson(String), + #[error("{0}")] + Downstream(String), +} + +#[derive(Debug, Deserialize)] +struct ClaudeJsonEnvelope { + #[serde(rename = "type")] + envelope_type: String, + subtype: Option<String>, + is_error: bool, + duration_ms: Option<u64>, + duration_api_ms: Option<u64>, + num_turns: Option<u64>, + result: Option<String>, + stop_reason: Option<String>, + session_id: Option<String>, + total_cost_usd: Option<f64>, + usage: Option<Value>, + #[serde(rename = "modelUsage")] + model_usage: Option<Value>, + #[serde(default)] + permission_denials: Vec<Value>, + fast_mode_state: Option<String>, + uuid: Option<String>, +} + +#[derive(Debug)] +struct ConsultResponse { + cwd: WorkingDirectory, + result: String, + duration_ms: u64, + duration_api_ms: Option<u64>, + num_turns: u64, + stop_reason: Option<String>, + session_id: Option<String>, + total_cost_usd: Option<f64>, + usage: Option<Value>, + model_usage: Option<Value>, + permission_denials: Vec<Value>, + fast_mode_state: Option<String>, + uuid: Option<String>, +} + +impl ConsultResponse { + fn model_name(&self) -> Option<String> { + let Value::Object(models) = self.model_usage.as_ref()? else { + return None; + }; + models.keys().next().cloned() + } +} + +fn deserialize<T: for<'de> Deserialize<'de>>( + value: Value, + operation: &str, + generation: Generation, +) -> Result<T, FaultRecord> { + serde_json::from_value(value).map_err(|error| { + FaultRecord::invalid_input( + generation, + FaultStage::Protocol, + operation, + format!("invalid params: {error}"), + ) + }) +} + +fn invalid_consult_request( + generation: Generation, + operation: &str, + error: ConsultRequestError, +) -> FaultRecord { + FaultRecord::invalid_input(generation, FaultStage::Worker, operation, error.to_string()) +} + +fn consult_fault( + generation: Generation, + operation: &str, + error: ConsultInvocationError, +) -> FaultRecord { + match error { + ConsultInvocationError::Spawn(source) => FaultRecord::process( + generation, + FaultStage::Claude, + operation, + source.to_string(), + ), + ConsultInvocationError::InvalidJson(detail) + | ConsultInvocationError::Downstream(detail) => { + FaultRecord::downstream(generation, FaultStage::Claude, operation, detail) + } + } +} + +fn invoke_claude(request: &ConsultRequest) -> Result<ConsultResponse, ConsultInvocationError> { + let mut command = Command::new(claude_binary()); + let _ = command + .arg("-p") + .arg("--output-format") + .arg("json") + .arg("--strict-mcp-config") + .arg("--mcp-config") + .arg(EMPTY_MCP_CONFIG) + .arg("--disable-slash-commands") + .arg("--no-chrome") + .arg("--tools") + .arg(CLAUDE_TOOLSET) + .arg("--permission-mode") + .arg("dontAsk"); + if let Some(max_turns) = request.max_turns { + let _ = command.arg("--max-turns").arg(max_turns.get().to_string()); + } + let output = command + .current_dir(request.cwd.as_path()) + .arg(request.prompt.as_str()) + .output() + .map_err(ConsultInvocationError::Spawn)?; + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_owned(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_owned(); + let envelope = match serde_json::from_slice::<ClaudeJsonEnvelope>(&output.stdout) { + Ok(envelope) => envelope, + Err(_error) if !output.status.success() => { + return Err(ConsultInvocationError::Downstream(downstream_message( + output.status.code(), + &stdout, + &stderr, + ))); + } + Err(error) => { + return Err(ConsultInvocationError::InvalidJson(format!( + "{error}; stdout={stdout}; stderr={stderr}" + ))); + } + }; + if envelope.envelope_type != "result" { + return Err(ConsultInvocationError::Downstream(format!( + "unexpected Claude envelope type `{}`", + envelope.envelope_type + ))); + } + if !output.status.success() + || envelope.is_error + || envelope.subtype.as_deref() != Some("success") + { + return Err(ConsultInvocationError::Downstream( + envelope + .result + .filter(|value| !value.trim().is_empty()) + .unwrap_or_else(|| downstream_message(output.status.code(), &stdout, &stderr)), + )); + } + Ok(ConsultResponse { + cwd: request.cwd.clone(), + result: envelope.result.unwrap_or_default(), + duration_ms: envelope.duration_ms.unwrap_or(0), + duration_api_ms: envelope.duration_api_ms, + num_turns: envelope.num_turns.unwrap_or(0), + stop_reason: envelope.stop_reason, + session_id: envelope.session_id, + total_cost_usd: envelope.total_cost_usd, + usage: envelope.usage, + model_usage: envelope.model_usage, + permission_denials: envelope.permission_denials, + fast_mode_state: envelope.fast_mode_state, + uuid: envelope.uuid, + }) +} + +fn downstream_message(status_code: Option<i32>, stdout: &str, stderr: &str) -> String { + if !stderr.is_empty() { + return stderr.to_owned(); + } + if !stdout.is_empty() { + return stdout.to_owned(); + } + format!("Claude Code exited with status {status_code:?}") +} + +fn claude_binary() -> PathBuf { + std::env::var_os(CLAUDE_BIN_ENV) + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("claude")) +} + +fn consult_output( + request: &ConsultRequest, + response: &ConsultResponse, + generation: Generation, + operation: &str, +) -> Result<ToolOutput, FaultRecord> { + let concise = json!({ + "response": response.result, + "cwd": response.cwd.display(), + "model": response.model_name(), + "duration_ms": response.duration_ms, + "num_turns": response.num_turns, + "stop_reason": response.stop_reason, + "session_id": response.session_id, + "total_cost_usd": response.total_cost_usd, + "permission_denial_count": response.permission_denials.len(), + }); + let full = json!({ + "response": response.result, + "cwd": response.cwd.display(), + "prompt": request.prompt.as_str(), + "max_turns": request.max_turns.map(TurnLimit::get), + "duration_ms": response.duration_ms, + "duration_api_ms": response.duration_api_ms, + "num_turns": response.num_turns, + "stop_reason": response.stop_reason, + "session_id": response.session_id, + "total_cost_usd": response.total_cost_usd, + "usage": response.usage, + "model_usage": response.model_usage, + "permission_denials": response.permission_denials, + "fast_mode_state": response.fast_mode_state, + "uuid": response.uuid, + }); + fallback_detailed_tool_output( + &concise, + &full, + concise_text(response), + Some(full_text(response)), + SurfaceKind::Read, + generation, + FaultStage::Worker, + operation, + ) +} + +fn concise_text(response: &ConsultResponse) -> String { + let mut status = vec![ + "consult ok".to_owned(), + format!("turns={}", response.num_turns), + format!("duration={}", render_duration_ms(response.duration_ms)), + ]; + if let Some(model) = response.model_name() { + status.push(format!("model={model}")); + } + if let Some(stop_reason) = response.stop_reason.as_deref() { + status.push(format!("stop={stop_reason}")); + } + if let Some(cost) = response.total_cost_usd { + status.push(format!("cost=${cost:.6}")); + } + + let mut lines = vec![status.join(" ")]; + lines.push(format!("cwd: {}", response.cwd.display())); + if let Some(session_id) = response.session_id.as_deref() { + lines.push(format!("session: {session_id}")); + } + if !response.permission_denials.is_empty() { + lines.push(format!( + "permission_denials: {}", + response.permission_denials.len() + )); + } + lines.push("response:".to_owned()); + lines.push(response.result.clone()); + lines.join("\n") +} + +fn full_text(response: &ConsultResponse) -> String { + let mut lines = vec![ + format!("consult ok turns={}", response.num_turns), + format!("cwd: {}", response.cwd.display()), + format!("duration: {}", render_duration_ms(response.duration_ms)), + ]; + if let Some(duration_api_ms) = response.duration_api_ms { + lines.push(format!( + "api_duration: {}", + render_duration_ms(duration_api_ms) + )); + } + if let Some(model) = response.model_name() { + lines.push(format!("model: {model}")); + } + if let Some(stop_reason) = response.stop_reason.as_deref() { + lines.push(format!("stop: {stop_reason}")); + } + if let Some(session_id) = response.session_id.as_deref() { + lines.push(format!("session: {session_id}")); + } + if let Some(cost) = response.total_cost_usd { + lines.push(format!("cost_usd: {cost:.6}")); + } + lines.push(format!( + "permission_denials: {}", + response.permission_denials.len() + )); + if let Some(fast_mode_state) = response.fast_mode_state.as_deref() { + lines.push(format!("fast_mode: {fast_mode_state}")); + } + if let Some(uuid) = response.uuid.as_deref() { + lines.push(format!("uuid: {uuid}")); + } + if let Some(usage) = usage_summary(response.usage.as_ref()) { + lines.push(format!("usage: {usage}")); + } + lines.push("response:".to_owned()); + lines.push(response.result.clone()); + lines.join("\n") +} + +fn usage_summary(usage: Option<&Value>) -> Option<String> { + let Value::Object(usage) = usage? else { + return None; + }; + let summary = usage + .iter() + .filter_map(|(key, value)| match value { + Value::Number(number) => Some((key.clone(), number.to_string())), + Value::String(text) if !text.is_empty() => Some((key.clone(), text.clone())), + _ => None, + }) + .collect::<BTreeMap<_, _>>(); + (!summary.is_empty()).then(|| { + summary + .into_iter() + .map(|(key, value)| format!("{key}={value}")) + .collect::<Vec<_>>() + .join(" ") + }) +} + +fn render_duration_ms(duration_ms: u64) -> String { + if duration_ms < 1_000 { + return format!("{duration_ms}ms"); + } + let seconds = duration_ms as f64 / 1_000.0; + format!("{seconds:.3}s") +} + +fn generation_from_wire(raw: u64) -> Generation { + let mut generation = Generation::genesis(); + for _ in 1..raw { + generation = generation.next(); + } + generation +} |