diff options
| author | main <main@swarm.moe> | 2026-03-23 19:16:10 -0400 |
|---|---|---|
| committer | main <main@swarm.moe> | 2026-03-23 19:16:10 -0400 |
| commit | 00949559a8a4757e1198e1ea582ebfcf7268fec4 (patch) | |
| tree | 7bb8821c1d597eb386998a6f0b4cba1b9e24de41 /crates/phone-opus/src/mcp/service.rs | |
| parent | dd2b64ed08b8ac55d6aaeb54d635b33b51eea790 (diff) | |
| download | phone_opus-00949559a8a4757e1198e1ea582ebfcf7268fec4.zip | |
Add blocking wait for background consult jobs
Diffstat (limited to 'crates/phone-opus/src/mcp/service.rs')
| -rw-r--r-- | crates/phone-opus/src/mcp/service.rs | 173 |
1 files changed, 145 insertions, 28 deletions
diff --git a/crates/phone-opus/src/mcp/service.rs b/crates/phone-opus/src/mcp/service.rs index ff25433..773a516 100644 --- a/crates/phone-opus/src/mcp/service.rs +++ b/crates/phone-opus/src/mcp/service.rs @@ -3,12 +3,12 @@ use std::fs; use std::io::{self, BufRead, Write}; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dirs::{home_dir, state_dir}; use libmcp::{Generation, SurfaceKind}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::{Map, Value, json}; use thiserror::Error; use time::{OffsetDateTime, format_description::well_known::Rfc3339}; use users::get_current_uid; @@ -132,11 +132,22 @@ struct ConsultJobArgs { job_id: String, } +#[derive(Debug, Deserialize)] +struct ConsultWaitArgs { + job_id: String, + timeout_ms: Option<u64>, + poll_interval_ms: Option<u64>, +} + #[derive(Debug, Default, Deserialize)] struct ConsultJobsArgs { limit: Option<u64>, } +const DEFAULT_CONSULT_WAIT_TIMEOUT_MS: u64 = 30 * 60 * 1_000; +const DEFAULT_CONSULT_WAIT_POLL_INTERVAL_MS: u64 = 1_000; +const MIN_CONSULT_WAIT_POLL_INTERVAL_MS: u64 = 10; + #[derive(Debug, Clone)] struct ConsultRequest { prompt: PromptText, @@ -406,6 +417,50 @@ struct BackgroundConsultJobRecord { failure: Option<BackgroundConsultFailure>, } +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +struct ConsultWaitConfig { + timeout_ms: u64, + poll_interval_ms: u64, +} + +impl ConsultWaitConfig { + fn parse(timeout_ms: Option<u64>, poll_interval_ms: Option<u64>) -> Self { + Self { + timeout_ms: timeout_ms.unwrap_or(DEFAULT_CONSULT_WAIT_TIMEOUT_MS), + poll_interval_ms: poll_interval_ms + .unwrap_or(DEFAULT_CONSULT_WAIT_POLL_INTERVAL_MS) + .max(MIN_CONSULT_WAIT_POLL_INTERVAL_MS), + } + } +} + +#[derive(Debug, Clone)] +struct BackgroundConsultWaitRequest { + job_id: BackgroundConsultJobId, + config: ConsultWaitConfig, +} + +impl BackgroundConsultWaitRequest { + fn parse(args: ConsultWaitArgs) -> Result<Self, ConsultRequestError> { + Ok(Self { + job_id: BackgroundConsultJobId::parse(args.job_id)?, + config: ConsultWaitConfig::parse(args.timeout_ms, args.poll_interval_ms), + }) + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +struct BackgroundConsultWaitMetadata { + waited_ms: u64, + timed_out: bool, +} + +#[derive(Debug, Clone)] +struct BackgroundConsultWaitOutcome { + record: BackgroundConsultJobRecord, + metadata: BackgroundConsultWaitMetadata, +} + impl BackgroundConsultJobRecord { fn new(job_id: BackgroundConsultJobId, request: BackgroundConsultRequest) -> Self { let now = unix_ms_now(); @@ -770,7 +825,28 @@ pub(crate) fn consult_job_tool_output( let record = load_background_consult_job(&job_id).map_err(|error| { FaultRecord::downstream(generation, stage, operation, error.to_string()) })?; - background_job_tool_output(&record, generation, stage, operation) + background_job_tool_output(&record, None, generation, stage, operation) +} + +pub(crate) fn consult_wait_tool_output( + arguments: Value, + generation: Generation, + stage: FaultStage, + operation: &str, +) -> Result<ToolOutput, FaultRecord> { + let args = deserialize::<ConsultWaitArgs>(arguments, operation, generation)?; + let request = BackgroundConsultWaitRequest::parse(args) + .map_err(|error| invalid_consult_request(generation, operation, error))?; + let outcome = wait_for_background_consult(&request).map_err(|error| { + FaultRecord::downstream(generation, stage, operation, error.to_string()) + })?; + background_job_tool_output( + &outcome.record, + Some(outcome.metadata), + generation, + stage, + operation, + ) } pub(crate) fn consult_jobs_tool_output( @@ -860,7 +936,7 @@ fn submit_background_consult( "requested_session_id": request.requested_session_id(), "session_mode": request.session_mode(), "prompt_prefix_injected": true, - "follow_up_tools": ["consult_job", "consult_jobs"], + "follow_up_tools": ["consult_wait", "consult_job", "consult_jobs"], }); let full = json!({ "mode": request.mode().as_str(), @@ -873,7 +949,7 @@ fn submit_background_consult( "prompt": request.prompt.as_str(), "effective_prompt": request.prompt.rendered(), "cwd": request.cwd.display(), - "follow_up_tools": ["consult_job", "consult_jobs"], + "follow_up_tools": ["consult_wait", "consult_job", "consult_jobs"], }); fallback_detailed_tool_output( &concise, @@ -935,44 +1011,81 @@ fn background_failure(error: ConsultInvocationError) -> BackgroundConsultFailure } } +fn wait_for_background_consult( + request: &BackgroundConsultWaitRequest, +) -> io::Result<BackgroundConsultWaitOutcome> { + let started_at = Instant::now(); + loop { + let record = load_background_consult_job(&request.job_id)?; + let waited_ms = elapsed_duration_ms(started_at.elapsed()); + if record.status.done() { + return Ok(BackgroundConsultWaitOutcome { + record, + metadata: BackgroundConsultWaitMetadata { + waited_ms, + timed_out: false, + }, + }); + } + if waited_ms >= request.config.timeout_ms { + return Ok(BackgroundConsultWaitOutcome { + record, + metadata: BackgroundConsultWaitMetadata { + waited_ms, + timed_out: true, + }, + }); + } + let remaining_ms = request.config.timeout_ms.saturating_sub(waited_ms); + let sleep_ms = remaining_ms.min(request.config.poll_interval_ms); + std::thread::sleep(Duration::from_millis(sleep_ms)); + } +} + fn background_job_tool_output( record: &BackgroundConsultJobRecord, + wait: Option<BackgroundConsultWaitMetadata>, generation: Generation, stage: FaultStage, operation: &str, ) -> Result<ToolOutput, FaultRecord> { - let concise = json!({ - "job_id": record.job_id.display(), - "status": record.status, - "done": record.status.done(), - "succeeded": record.status.success(), - "failed": record.status.failed(), - "created_unix_ms": record.created_unix_ms, - "updated_unix_ms": record.updated_unix_ms, - "started_unix_ms": record.started_unix_ms, - "finished_unix_ms": record.finished_unix_ms, - "runner_pid": record.runner_pid, - "cwd": record.request.cwd, - "requested_session_id": record.request.session_id, - "prompt_prefix_injected": record.prompt_prefix_injected, - "result": record.result.as_ref().map(|result| json!({ + let mut concise = match record.summary() { + Value::Object(object) => object, + _ => Map::new(), + }; + if let Some(wait) = wait { + let _ = concise.insert("waited_ms".to_owned(), json!(wait.waited_ms)); + let _ = concise.insert("timed_out".to_owned(), json!(wait.timed_out)); + } + let _ = concise.insert( + "result".to_owned(), + json!(record.result.as_ref().map(|result| json!({ "response": result.response, "persisted_output_path": result.persisted_output_path, "duration_ms": result.duration_ms, "num_turns": result.num_turns, "session_id": result.session_id, "model": result.model_name(), - })), - "failure": record.failure, - }); - let full = json!({ - "job": record, - }); + }))), + ); + let _ = concise.insert("failure".to_owned(), json!(record.failure)); + let mut full = Map::from_iter([("job".to_owned(), json!(record))]); + if let Some(wait) = wait { + let _ = full.insert("waited_ms".to_owned(), json!(wait.waited_ms)); + let _ = full.insert("timed_out".to_owned(), json!(wait.timed_out)); + } let mut lines = vec![format!( "job={} status={:?}", record.job_id.display(), record.status )]; + if let Some(wait) = wait { + lines.push(format!( + "waited={} timed_out={}", + render_duration_ms(wait.waited_ms), + wait.timed_out + )); + } if let Some(result) = record.result.as_ref() { lines.push(format!( "result ready model={} turns={} duration={}", @@ -988,8 +1101,8 @@ fn background_job_tool_output( lines.push(format!("failure={} {}", failure.class, failure.detail)); } fallback_detailed_tool_output( - &concise, - &full, + &Value::Object(concise), + &Value::Object(full), lines.join("\n"), None, SurfaceKind::Read, @@ -1552,6 +1665,10 @@ fn model_name(model_usage: Option<&Value>) -> Option<String> { models.keys().next().cloned() } +fn elapsed_duration_ms(duration: Duration) -> u64 { + duration.as_millis().try_into().unwrap_or(u64::MAX) +} + fn render_duration_ms(duration_ms: u64) -> String { if duration_ms < 1_000 { return format!("{duration_ms}ms"); |