From 57db4dc94dbf571ac8a393f61549def5afaa0209 Mon Sep 17 00:00:00 2001 From: main Date: Tue, 24 Mar 2026 19:26:58 -0400 Subject: Predeclare and stream consult session ids --- crates/phone-opus/src/mcp/service.rs | 379 +++++++++++++++++++++++++++++------ 1 file changed, 322 insertions(+), 57 deletions(-) (limited to 'crates/phone-opus/src/mcp/service.rs') diff --git a/crates/phone-opus/src/mcp/service.rs b/crates/phone-opus/src/mcp/service.rs index c5c2d66..39cc825 100644 --- a/crates/phone-opus/src/mcp/service.rs +++ b/crates/phone-opus/src/mcp/service.rs @@ -1,10 +1,11 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fs; -use std::io::{self, BufRead, Write}; +use std::io::{self, BufRead, Read, Write}; #[cfg(unix)] use std::os::unix::fs::symlink; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; +use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dirs::{home_dir, state_dir}; @@ -145,7 +146,7 @@ struct ConsultRequest { cwd: WorkingDirectory, context_key: ConsultContextKey, fresh_context: bool, - session: Option, + session_plan: ConsultSessionPlan, } impl ConsultRequest { @@ -154,34 +155,70 @@ impl ConsultRequest { let cwd = WorkingDirectory::resolve(args.cwd)?; let context_key = ConsultContextKey::from_cwd(&cwd); let fresh_context = args.fresh_context.unwrap_or(false); + let session_plan = if fresh_context { + ConsultSessionPlan::fresh() + } else { + load_consult_context(&context_key) + .map_err(|source| ConsultRequestError::ContextIndex { source })? + .and_then(ConsultSessionPlan::from_stored) + .unwrap_or_else(ConsultSessionPlan::fresh) + }; Ok(Self { prompt, cwd, - session: if fresh_context { - None - } else { - load_consult_context(&context_key) - .map_err(|source| ConsultRequestError::ContextIndex { source })? - }, context_key, fresh_context, + session_plan, }) } fn context_mode(&self) -> &'static str { - if self.session.is_some() { - "reused" - } else { - "fresh" - } + self.session_plan.context_mode() } fn reused_session_id(&self) -> Option { - self.session.as_ref().map(SessionHandle::display) + self.session_plan.reused_session_id() + } + + fn planned_session_id(&self) -> String { + self.session_plan.planned_session().display() + } + + fn launch_resume_session(&self) -> Option { + self.session_plan + .resume_session() + .map(SessionHandle::display) + } + + fn launch_session_id(&self) -> Option { + match self.session_plan { + ConsultSessionPlan::Start { .. } => Some(self.planned_session_id()), + ConsultSessionPlan::Resume(_) => None, + } } fn remember_context(&self, session_id: Option<&str>) -> io::Result<()> { - remember_consult_context(&self.context_key, session_id) + confirm_consult_context( + &self.context_key, + session_id, + self.session_plan.planned_session(), + ) + } + + fn remember_planned_context(&self) -> io::Result<()> { + if let ConsultSessionPlan::Start { session, .. } = &self.session_plan { + remember_planned_consult_context(&self.context_key, session) + } else { + Ok(()) + } + } + + fn current_context_session_id(&self) -> Option { + load_consult_context(&self.context_key) + .ok() + .flatten() + .and_then(ConsultSessionPlan::from_stored) + .map(|plan| plan.planned_session().display()) } #[allow(dead_code, reason = "background submission is parked but not exposed")] @@ -261,6 +298,10 @@ impl WorkingDirectory { struct SessionHandle(Uuid); impl SessionHandle { + fn fresh() -> Self { + Self(Uuid::new_v4()) + } + fn parse(raw: &str) -> Option { Uuid::parse_str(raw).ok().map(Self) } @@ -286,32 +327,110 @@ impl ConsultContextKey { #[derive(Debug, Clone, Deserialize, Serialize)] struct StoredConsultContext { session_id: String, + #[serde(default = "default_consult_context_state")] + state: StoredConsultContextState, updated_unix_ms: u64, } +#[derive(Debug, Clone, Copy, Default, Deserialize, Eq, PartialEq, Serialize)] +#[serde(rename_all = "snake_case")] +enum StoredConsultContextState { + Planned, + #[default] + Confirmed, +} + +const fn default_consult_context_state() -> StoredConsultContextState { + StoredConsultContextState::Confirmed +} + #[derive(Debug, Default, Deserialize, Serialize)] struct ConsultContextIndex { by_cwd: BTreeMap, } impl ConsultContextIndex { - fn session_for(&self, key: &ConsultContextKey) -> Option { - self.by_cwd - .get(key.as_str()) - .and_then(|entry| SessionHandle::parse(entry.session_id.as_str())) + fn context_for(&self, key: &ConsultContextKey) -> Option { + self.by_cwd.get(key.as_str()).cloned() } - fn remember(&mut self, key: &ConsultContextKey, session: &SessionHandle) { + fn remember( + &mut self, + key: &ConsultContextKey, + session: &SessionHandle, + state: StoredConsultContextState, + ) { let _ = self.by_cwd.insert( key.as_str().to_owned(), StoredConsultContext { session_id: session.display(), + state, updated_unix_ms: unix_ms_now(), }, ); } } +#[derive(Debug, Clone)] +enum ConsultSessionPlan { + Start { + session: SessionHandle, + reused: bool, + }, + Resume(SessionHandle), +} + +impl ConsultSessionPlan { + fn fresh() -> Self { + Self::Start { + session: SessionHandle::fresh(), + reused: false, + } + } + + fn from_stored(context: StoredConsultContext) -> Option { + let session = SessionHandle::parse(context.session_id.as_str())?; + Some(match context.state { + StoredConsultContextState::Planned => Self::Start { + session, + reused: true, + }, + StoredConsultContextState::Confirmed => Self::Resume(session), + }) + } + + fn planned_session(&self) -> &SessionHandle { + match self { + Self::Start { session, .. } | Self::Resume(session) => session, + } + } + + fn resume_session(&self) -> Option<&SessionHandle> { + match self { + Self::Resume(session) => Some(session), + Self::Start { .. } => None, + } + } + + fn context_mode(&self) -> &'static str { + match self { + Self::Start { reused: false, .. } => "fresh", + Self::Start { reused: true, .. } | Self::Resume(_) => "reused", + } + } + + fn reused_session_id(&self) -> Option { + match self { + Self::Start { + reused: true, + session, + } + | Self::Resume(session) => Some(session.display()), + Self::Start { reused: false, .. } => None, + } + } +} + #[derive(Debug, Clone, Deserialize, Eq, PartialEq, Serialize)] struct BackgroundConsultRequest { prompt: String, @@ -565,18 +684,27 @@ struct ClaudeJsonEnvelope { uuid: Option, } +#[derive(Debug, Default)] +struct ClaudeStreamCapture { + stdout: String, + observed_session_id: Option, + final_envelope: Option, +} + #[derive(Debug)] struct ConsultResponse { cwd: WorkingDirectory, result: String, persisted_output_path: PersistedConsultPath, context_mode: &'static str, + planned_session_id: String, reused_session_id: Option, duration_ms: u64, duration_api_ms: Option, num_turns: u64, stop_reason: Option, session_id: Option, + observed_session_id: Option, total_cost_usd: Option, usage: Option, model_usage: Option, @@ -870,10 +998,15 @@ fn consult_fault_context(request: &ConsultRequest, error: &ConsultInvocationErro | ConsultInvocationError::Downstream(detail) => Some(detail.as_str()), }; let reused_session_id = request.reused_session_id(); - let downstream_session_id = detail.and_then(downstream_session_id); - let resume_session_id = downstream_session_id + let planned_session_id = request.planned_session_id(); + let observed_session_id = detail + .and_then(downstream_session_id) + .clone() + .or_else(|| request.current_context_session_id()); + let resume_session_id = observed_session_id .clone() - .or_else(|| reused_session_id.clone()); + .or_else(|| reused_session_id.clone()) + .or_else(|| Some(planned_session_id.clone())); let quota_reset_hint = detail.and_then(quota_reset_hint); let quota_limited = quota_reset_hint.is_some(); let retry_hint = consult_retry_hint(quota_limited, resume_session_id.as_deref()); @@ -881,8 +1014,9 @@ fn consult_fault_context(request: &ConsultRequest, error: &ConsultInvocationErro consult: Some(ConsultFaultContext { cwd: request.cwd.display(), context_mode: request.context_mode().to_owned(), + planned_session_id, reused_session_id, - downstream_session_id, + observed_session_id, resume_session_id, quota_limited, quota_reset_hint, @@ -1147,7 +1281,7 @@ fn wait_for_background_consult( } let remaining_ms = request.config.timeout_ms.saturating_sub(waited_ms); let sleep_ms = remaining_ms.min(request.config.poll_interval_ms); - std::thread::sleep(Duration::from_millis(sleep_ms)); + thread::sleep(Duration::from_millis(sleep_ms)); } } @@ -1332,7 +1466,7 @@ fn invoke_claude(request: &ConsultRequest) -> Result Result(&output.stdout) { - Ok(envelope) => envelope, - Err(_error) if !output.status.success() => { + request + .remember_planned_context() + .map_err(ConsultInvocationError::Spawn)?; + let stdout = child + .stdout + .take() + .ok_or_else(|| ConsultInvocationError::Spawn(io::Error::other("missing Claude stdout")))?; + let stderr = child + .stderr + .take() + .ok_or_else(|| ConsultInvocationError::Spawn(io::Error::other("missing Claude stderr")))?; + let stderr_reader = thread::spawn(move || -> io::Result { + let mut stderr = stderr; + let mut buffer = String::new(); + let _ = stderr.read_to_string(&mut buffer)?; + Ok(buffer) + }); + let capture = capture_claude_stream(stdout, request)?; + let status = child.wait().map_err(ConsultInvocationError::Spawn)?; + let stderr = stderr_reader + .join() + .map_err(|_| { + ConsultInvocationError::Spawn(io::Error::other("Claude stderr reader panicked")) + })? + .map_err(ConsultInvocationError::Spawn)? + .trim() + .to_owned(); + let stdout = capture.stdout.trim().to_owned(); + let envelope = match capture.final_envelope { + Some(envelope) => envelope, + None if !status.success() => { return Err(ConsultInvocationError::Downstream(downstream_message( - output.status.code(), + status.code(), &stdout, &stderr, ))); } - Err(error) => { + None => { return Err(ConsultInvocationError::InvalidJson(format!( - "{error}; stdout={stdout}; stderr={stderr}" + "missing Claude result envelope; stdout={stdout}; stderr={stderr}" ))); } }; @@ -1375,34 +1540,46 @@ fn invoke_claude(request: &ConsultRequest) -> Result Result Result { + let mut capture = ClaudeStreamCapture::default(); + let reader = io::BufReader::new(stdout); + for line in reader.lines() { + let line = line.map_err(ConsultInvocationError::Spawn)?; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + capture.stdout.push_str(trimmed); + capture.stdout.push('\n'); + let Ok(value) = serde_json::from_str::(trimmed) else { + continue; + }; + if capture.observed_session_id.is_none() + && let Some(session_id) = stream_session_id(&value) + { + request + .remember_context(Some(session_id.as_str())) + .map_err(ConsultInvocationError::Spawn)?; + capture.observed_session_id = Some(session_id); + } + if value.get("type").and_then(Value::as_str) == Some("result") { + let envelope = + serde_json::from_value::(value).map_err(|error| { + ConsultInvocationError::InvalidJson(format!("{error}; stream_line={trimmed}")) + })?; + if let Some(session_id) = envelope.session_id.as_deref() { + request + .remember_context(Some(session_id)) + .map_err(ConsultInvocationError::Spawn)?; + capture.observed_session_id = Some(session_id.to_owned()); + } + capture.final_envelope = Some(envelope); + } + } + Ok(capture) +} + +fn stream_session_id(value: &Value) -> Option { + match value { + Value::Object(object) => object + .iter() + .find_map(|(key, value)| { + ((key == "session_id") || (key == "sessionId")) + .then_some(value) + .and_then(Value::as_str) + .and_then(SessionHandle::parse) + .map(|session| session.display()) + }) + .or_else(|| object.values().find_map(stream_session_id)), + Value::Array(array) => array.iter().find_map(stream_session_id), + _ => None, + } +} + fn downstream_message(status_code: Option, stdout: &str, stderr: &str) -> String { if !stderr.is_empty() { return stderr.to_owned(); @@ -1460,16 +1696,29 @@ fn load_consult_context_index() -> io::Result { } } -fn load_consult_context(key: &ConsultContextKey) -> io::Result> { - Ok(load_consult_context_index()?.session_for(key)) +fn load_consult_context(key: &ConsultContextKey) -> io::Result> { + Ok(load_consult_context_index()?.context_for(key)) } -fn remember_consult_context(key: &ConsultContextKey, session_id: Option<&str>) -> io::Result<()> { - let Some(session_id) = session_id.and_then(SessionHandle::parse) else { - return Ok(()); - }; +fn remember_planned_consult_context( + key: &ConsultContextKey, + session_id: &SessionHandle, +) -> io::Result<()> { let mut index = load_consult_context_index()?; - index.remember(key, &session_id); + index.remember(key, session_id, StoredConsultContextState::Planned); + write_json_file(consult_context_index_path()?.as_path(), &index) +} + +fn confirm_consult_context( + key: &ConsultContextKey, + observed_session_id: Option<&str>, + fallback_session_id: &SessionHandle, +) -> io::Result<()> { + let session_id = observed_session_id + .and_then(SessionHandle::parse) + .unwrap_or_else(|| fallback_session_id.clone()); + let mut index = load_consult_context_index()?; + index.remember(key, &session_id, StoredConsultContextState::Confirmed); write_json_file(consult_context_index_path()?.as_path(), &index) } @@ -1537,8 +1786,10 @@ fn persist_consult_output( request: &ConsultRequest, result: &str, envelope: &ClaudeJsonEnvelope, + session_id: Option<&str>, + observed_session_id: Option<&str>, ) -> io::Result { - let path = PersistedConsultPath::new(request, envelope.session_id.as_deref())?; + let path = PersistedConsultPath::new(request, session_id)?; let saved_at = OffsetDateTime::now_utc() .format(&Rfc3339) .map_err(|error| io::Error::other(error.to_string()))?; @@ -1551,6 +1802,7 @@ fn persist_consult_output( "prompt_prefix": CLAUDE_CONSULT_PREFIX, "effective_prompt": request.prompt.rendered(), "context_mode": request.context_mode(), + "planned_session_id": request.planned_session_id(), "reused_session_id": request.reused_session_id(), "response": result, "model": model_name(envelope.model_usage.as_ref()), @@ -1558,7 +1810,8 @@ fn persist_consult_output( "duration_api_ms": envelope.duration_api_ms, "num_turns": envelope.num_turns.unwrap_or(0), "stop_reason": envelope.stop_reason, - "session_id": envelope.session_id, + "session_id": session_id, + "observed_session_id": observed_session_id, "total_cost_usd": envelope.total_cost_usd, "usage": envelope.usage, "model_usage": envelope.model_usage, @@ -1650,6 +1903,7 @@ fn consult_output( "cwd": response.cwd.display(), "persisted_output_path": response.persisted_output_path.display(), "context_mode": response.context_mode, + "planned_session_id": response.planned_session_id, "reused_session_id": response.reused_session_id, "prompt_prefix_injected": true, "model": response.model_name(), @@ -1657,6 +1911,7 @@ fn consult_output( "num_turns": response.num_turns, "stop_reason": response.stop_reason, "session_id": response.session_id, + "observed_session_id": response.observed_session_id, "total_cost_usd": response.total_cost_usd, "permission_denial_count": response.permission_denials.len(), }); @@ -1668,12 +1923,14 @@ fn consult_output( "prompt_prefix": CLAUDE_CONSULT_PREFIX, "effective_prompt": request.prompt.rendered(), "context_mode": response.context_mode, + "planned_session_id": response.planned_session_id, "reused_session_id": response.reused_session_id, "duration_ms": response.duration_ms, "duration_api_ms": response.duration_api_ms, "num_turns": response.num_turns, "stop_reason": response.stop_reason, "session_id": response.session_id, + "observed_session_id": response.observed_session_id, "total_cost_usd": response.total_cost_usd, "usage": response.usage, "model_usage": response.model_usage, @@ -1712,9 +1969,13 @@ fn concise_text(_request: &ConsultRequest, response: &ConsultResponse) -> String let mut lines = vec![status.join(" ")]; lines.push(format!("cwd: {}", response.cwd.display())); + lines.push(format!("planned_session: {}", response.planned_session_id)); if let Some(session_id) = response.reused_session_id.as_deref() { lines.push(format!("reused_session: {session_id}")); } + if let Some(session_id) = response.observed_session_id.as_deref() { + lines.push(format!("observed_session: {session_id}")); + } if let Some(session_id) = response.session_id.as_deref() { lines.push(format!("session: {session_id}")); } @@ -1740,11 +2001,15 @@ fn full_text(_request: &ConsultRequest, response: &ConsultResponse) -> String { response.context_mode, response.num_turns ), format!("cwd: {}", response.cwd.display()), + format!("planned_session: {}", response.planned_session_id), format!("duration: {}", render_duration_ms(response.duration_ms)), ]; if let Some(session_id) = response.reused_session_id.as_deref() { lines.push(format!("reused_session: {session_id}")); } + if let Some(session_id) = response.observed_session_id.as_deref() { + lines.push(format!("observed_session: {session_id}")); + } if let Some(duration_api_ms) = response.duration_api_ms { lines.push(format!( "api_duration: {}", -- cgit v1.2.3