From d4d155c8c843ba3cd9d7a8742522c95044f62d3b Mon Sep 17 00:00:00 2001 From: flazouh Date: Sun, 5 Apr 2026 23:08:11 +0300 Subject: [PATCH 1/4] feat(acp): add machine stdio transport for Acepe Expose a machine stdio entrypoint in Forge and route it through a real ACP stdio transport so Acepe can launch Forge as an installable provider instead of depending on an unpublished branch. Co-Authored-By: ForgeCode --- Cargo.lock | 93 ++++++ Cargo.toml | 1 + crates/forge_api/src/api.rs | 3 + crates/forge_api/src/forge_api.rs | 5 + crates/forge_app/Cargo.toml | 4 + crates/forge_app/src/acp/adapter.rs | 120 ++++++++ crates/forge_app/src/acp/conversion.rs | 290 +++++++++++++++++++ crates/forge_app/src/acp/error.rs | 27 ++ crates/forge_app/src/acp/mod.rs | 88 ++++++ crates/forge_app/src/acp/prompt_handler.rs | 283 ++++++++++++++++++ crates/forge_app/src/acp/session_handlers.rs | 225 ++++++++++++++ crates/forge_app/src/acp/state_builders.rs | 180 ++++++++++++ crates/forge_app/src/acp_app.rs | 89 ++++++ crates/forge_app/src/lib.rs | 3 + crates/forge_main/src/acp_runner.rs | 57 ++++ crates/forge_main/src/cli.rs | 37 +++ crates/forge_main/src/lib.rs | 1 + crates/forge_main/src/ui.rs | 8 + 18 files changed, 1514 insertions(+) create mode 100644 crates/forge_app/src/acp/adapter.rs create mode 100644 crates/forge_app/src/acp/conversion.rs create mode 100644 crates/forge_app/src/acp/error.rs create mode 100644 crates/forge_app/src/acp/mod.rs create mode 100644 crates/forge_app/src/acp/prompt_handler.rs create mode 100644 crates/forge_app/src/acp/session_handlers.rs create mode 100644 crates/forge_app/src/acp/state_builders.rs create mode 100644 crates/forge_app/src/acp_app.rs create mode 100644 crates/forge_main/src/acp_runner.rs diff --git a/Cargo.lock b/Cargo.lock index ab087820c8..59b465bc18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,37 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "agent-client-protocol" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "499b7ff5c6c842e43fb188f6da7c99a258ae89a9df8c896d6e9784da9b4b23e7" +dependencies = [ + "agent-client-protocol-schema", + "anyhow", + "async-broadcast", + "async-trait", + "derive_more", + "futures", + "log", + "serde", + "serde_json", +] + +[[package]] +name = "agent-client-protocol-schema" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44bc1fef9c32f03bce2ab44af35b6f483bfd169bf55cc59beeb2e3b1a00ae4d1" +dependencies = [ + "anyhow", + "derive_more", + "schemars 1.2.1", + "serde", + "serde_json", + "strum 0.27.2", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -130,6 +161,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-broadcast" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435a87a52755b8f27fcf321ac4f04b2802e337c8c4872923137471ec39c37532" +dependencies = [ + "event-listener", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "async-compression" version = "0.4.39" @@ -873,6 +916,15 @@ version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "config" version = "0.15.22" @@ -1787,6 +1839,27 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "eventsource-stream" version = "0.2.3" @@ -1924,10 +1997,12 @@ dependencies = [ name = "forge_app" version = "0.1.0" dependencies = [ + "agent-client-protocol", "anyhow", "async-recursion", "async-trait", "backon", + "base64 0.22.1", "bytes", "chrono", "console", @@ -1967,9 +2042,11 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", + "tokio-util", "tonic", "tracing", "url", + "uuid", ] [[package]] @@ -4378,6 +4455,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.5" @@ -6177,6 +6260,15 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", +] + [[package]] name = "strum" version = "0.28.0" @@ -6634,6 +6726,7 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", diff --git a/Cargo.toml b/Cargo.toml index ac214161cb..a2ac2f4642 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ opt-level = 3 strip = true [workspace.dependencies] +agent-client-protocol = { version = "0.9", features = ["unstable_session_model"] } anyhow = "1.0.102" async-recursion = "1.1.1" async-stream = "0.3" diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index b13863b0b2..7d1d2a04a0 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -251,4 +251,7 @@ pub trait API: Sync + Send { &self, data_parameters: DataGenerationParameters, ) -> Result>>; + + /// Starts the ACP (Agent Communication Protocol) server over stdio. + async fn acp_start_stdio(&self) -> Result<()>; } diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 36df08a4c4..223d61156e 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -404,6 +404,11 @@ impl Result<()> { + let acp_app = forge_app::AcpApp::new(self.services.clone()); + acp_app.start_stdio().await + } + async fn get_default_provider(&self) -> Result> { let provider_id = self.services.get_default_provider().await?; self.services.get_provider(provider_id).await diff --git a/crates/forge_app/Cargo.toml b/crates/forge_app/Cargo.toml index 8f5f1873b5..476a3be713 100644 --- a/crates/forge_app/Cargo.toml +++ b/crates/forge_app/Cargo.toml @@ -48,6 +48,10 @@ lazy_static.workspace = true forge_json_repair.workspace = true tonic.workspace = true +agent-client-protocol.workspace = true +tokio-util = { workspace = true, features = ["compat"] } +base64.workspace = true +uuid.workspace = true [dev-dependencies] diff --git a/crates/forge_app/src/acp/adapter.rs b/crates/forge_app/src/acp/adapter.rs new file mode 100644 index 0000000000..d7c1ebc73d --- /dev/null +++ b/crates/forge_app/src/acp/adapter.rs @@ -0,0 +1,120 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use agent_client_protocol as acp; +use forge_domain::{AgentId, ConversationId}; +use tokio::sync::{Mutex, Notify, mpsc}; + +use crate::Services; + +use super::error::{Error, Result}; + +#[derive(Clone)] +pub(super) struct SessionState { + pub conversation_id: ConversationId, + pub agent_id: AgentId, + pub cancel_notify: Option>, +} + +pub(crate) struct AcpAdapter { + pub(super) services: Arc, + pub(super) session_update_tx: mpsc::UnboundedSender, + pub(super) client_conn: Arc>>>, + sessions: Arc>>, +} + +impl AcpAdapter { + pub(crate) fn new( + services: Arc, + session_update_tx: mpsc::UnboundedSender, + ) -> Self { + Self { + services, + session_update_tx, + client_conn: Arc::new(Mutex::new(None)), + sessions: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub(crate) async fn set_client_connection(&self, conn: Arc) { + *self.client_conn.lock().await = Some(conn); + } + + pub(super) async fn store_session(&self, session_id: String, state: SessionState) { + self.sessions.lock().await.insert(session_id, state); + } + + pub(super) async fn session_state(&self, session_id: &str) -> Result { + self.sessions + .lock() + .await + .get(session_id) + .cloned() + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found"))) + } + + pub(super) async fn update_session_agent( + &self, + session_id: &str, + agent_id: AgentId, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.agent_id = agent_id; + Ok(()) + } + + pub(super) async fn set_cancel_notify( + &self, + session_id: &str, + cancel_notify: Option>, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.cancel_notify = cancel_notify; + Ok(()) + } + + pub(super) async fn cancel_session(&self, session_id: &str) -> bool { + let notify = self + .sessions + .lock() + .await + .get(session_id) + .and_then(|state| state.cancel_notify.clone()); + + if let Some(notify) = notify { + notify.notify_waiters(); + true + } else { + false + } + } + + pub(super) async fn ensure_session( + &self, + session_id: &str, + conversation_id: ConversationId, + agent_id: AgentId, + ) -> SessionState { + let mut sessions = self.sessions.lock().await; + sessions + .entry(session_id.to_string()) + .or_insert_with(|| SessionState { + conversation_id, + agent_id, + cancel_notify: None, + }) + .clone() + } + + pub(super) fn send_notification(&self, notification: acp::SessionNotification) -> Result<()> { + self.session_update_tx + .send(notification) + .map_err(|_| Error::Application(anyhow::anyhow!("Failed to send notification"))) + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/conversion.rs b/crates/forge_app/src/acp/conversion.rs new file mode 100644 index 0000000000..31a9021ce0 --- /dev/null +++ b/crates/forge_app/src/acp/conversion.rs @@ -0,0 +1,290 @@ +use std::path::PathBuf; + +use agent_client_protocol as acp; +use forge_domain::{ + Agent, AgentId, Attachment, AttachmentContent, FileInfo, ToolCallFull, ToolName, ToolOutput, + ToolValue, +}; + +use super::error::{Error, Result}; + +pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { + match tool_name.as_str() { + "read" => acp::ToolKind::Read, + "write" | "patch" => acp::ToolKind::Edit, + "remove" | "undo" => acp::ToolKind::Delete, + "fs_search" | "sem_search" => acp::ToolKind::Search, + "shell" => acp::ToolKind::Execute, + "fetch" => acp::ToolKind::Fetch, + "sage" => acp::ToolKind::Think, + _ => { + let name = tool_name.as_str(); + if name.starts_with("mcp_") { + if name.contains("read") + || name.contains("get") + || name.contains("fetch") + || name.contains("list") + || name.contains("show") + || name.contains("view") + || name.contains("load") + { + acp::ToolKind::Read + } else if name.contains("search") + || name.contains("query") + || name.contains("find") + || name.contains("filter") + || name.contains("lookup") + { + acp::ToolKind::Search + } else if name.contains("write") + || name.contains("update") + || name.contains("create") + || name.contains("set") + || name.contains("add") + || name.contains("insert") + || name.contains("push") + || name.contains("merge") + || name.contains("fork") + || name.contains("comment") + || name.contains("assign") + || name.contains("request") + { + acp::ToolKind::Edit + } else if name.contains("delete") + || name.contains("remove") + || name.contains("drop") + || name.contains("clear") + || name.contains("close") + || name.contains("cancel") + { + acp::ToolKind::Delete + } else if name.contains("execute") + || name.contains("run") + || name.contains("start") + || name.contains("invoke") + || name.contains("call") + { + acp::ToolKind::Execute + } else { + acp::ToolKind::Other + } + } else { + acp::ToolKind::Other + } + } + } +} + +pub(crate) fn extract_file_locations( + tool_name: &ToolName, + arguments: &serde_json::Value, +) -> Vec { + match tool_name.as_str() { + "read" | "write" | "patch" | "remove" | "undo" => arguments + .get("file_path") + .and_then(|value| value.as_str()) + .map(|file_path| vec![acp::ToolCallLocation::new(PathBuf::from(file_path))]) + .unwrap_or_default(), + _ => vec![], + } +} + +pub(crate) fn map_tool_call_to_acp(tool_call: &ToolCallFull) -> acp::ToolCall { + let tool_call_id = tool_call + .call_id + .as_ref() + .map(|id| id.as_str().to_string()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + let locations = extract_file_locations( + &tool_call.name, + &serde_json::to_value(&tool_call.arguments).unwrap_or(serde_json::json!({})), + ); + + acp::ToolCall::new(tool_call_id, tool_call.name.as_str().to_string()) + .kind(map_tool_kind(&tool_call.name)) + .status(acp::ToolCallStatus::Pending) + .locations(locations) + .raw_input( + serde_json::to_value(&tool_call.arguments) + .ok() + .filter(|value| !value.is_null()), + ) +} + +pub(crate) struct ToolOutputConverter { + _private: (), +} + +impl ToolOutputConverter { + pub(crate) fn new(output: &ToolOutput) -> Self { + let _ = output; + Self { _private: () } + } + + pub(crate) fn convert(output: &ToolOutput) -> Vec { + let converter = Self::new(output); + output + .values + .iter() + .filter_map(|value| converter.convert_value(value)) + .collect() + } + + fn convert_value(&self, value: &ToolValue) -> Option { + match value { + ToolValue::Text(text) => self.convert_text(text), + ToolValue::AI { value, .. } => self.convert_text(value), + ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), + ))), + ToolValue::Empty => None, + } + } + + fn convert_text(&self, text: &str) -> Option { + if text.is_empty() { + None + } else { + Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), + ))) + } + } +} + +pub(crate) fn acp_resource_to_attachment(resource: &acp::EmbeddedResource) -> Result { + let (content_text, uri) = match &resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(text_resource) => { + (text_resource.text.clone(), text_resource.uri.clone()) + } + acp::EmbeddedResourceResource::BlobResourceContents(blob_resource) => { + let decoded = base64::Engine::decode( + &base64::engine::general_purpose::STANDARD, + &blob_resource.blob, + ) + .map_err(|error| { + Error::Application(anyhow::anyhow!("Failed to decode base64 blob: {}", error)) + })?; + let text = String::from_utf8(decoded).map_err(|error| { + Error::Application(anyhow::anyhow!("Failed to decode UTF-8: {}", error)) + })?; + (text, blob_resource.uri.clone()) + } + _ => { + return Err(Error::Application(anyhow::anyhow!( + "Unsupported resource type" + ))) + } + }; + + let path = uri_to_path(&uri); + let total_lines = content_text.lines().count() as u64; + let info = FileInfo::new(1, total_lines, total_lines, String::new()); + let content = AttachmentContent::FileContent { + content: content_text, + info, + }; + + Ok(Attachment { path, content }) +} + +pub(crate) fn uri_to_path(uri: &str) -> String { + if let Some(path) = uri.strip_prefix("file://") { + if path.len() > 2 && path.chars().nth(2) == Some(':') { + path.trim_start_matches('/').to_string() + } else { + path.to_string() + } + } else { + uri.to_string() + } +} + +pub(crate) fn build_session_mode_state( + agents: &[Agent], + current_agent_id: &AgentId, +) -> acp::SessionModeState { + let available_modes = agents + .iter() + .map(|agent| { + acp::SessionMode::new( + acp::SessionModeId::new(agent.id.to_string()), + agent.id.to_string(), + ) + .description(agent.description.clone()) + }) + .collect(); + + acp::SessionModeState::new( + acp::SessionModeId::new(current_agent_id.to_string()), + available_modes, + ) +} + +#[cfg(test)] +mod tests { + use forge_domain::{ConversationId, Image}; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_uri_to_path_preserves_non_file_uri() { + let fixture = "relative/path.txt"; + let actual = uri_to_path(fixture); + let expected = "relative/path.txt".to_string(); + assert_eq!(actual, expected); + } + + #[test] + fn test_markdown_sent_to_acp_not_xml() { + let fixture = ToolOutput::text("## File: test.txt\n\nContent here"); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + if let acp::ContentBlock::Text(text) = &content.content { + assert_eq!(text.text, "## File: test.txt\n\nContent here"); + } else { + panic!("Expected text content block"); + } + } else { + panic!("Expected content"); + } + } + + #[test] + fn test_ai_output_sent_to_acp_as_text() { + let fixture = ToolOutput::ai(ConversationId::generate(), "Agent result"); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + if let acp::ContentBlock::Text(text) = &content.content { + assert_eq!(text.text, "Agent result"); + } else { + panic!("Expected text content block"); + } + } else { + panic!("Expected content"); + } + } + + #[test] + fn test_image_sent_to_acp() { + let image = Image::new_bytes(vec![1, 2, 3, 4], "image/png".to_string()); + let fixture = ToolOutput::image(image); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + assert!(matches!(content.content, acp::ContentBlock::Image(_))); + } else { + panic!("Expected content"); + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/error.rs b/crates/forge_app/src/acp/error.rs new file mode 100644 index 0000000000..1dbf4a0696 --- /dev/null +++ b/crates/forge_app/src/acp/error.rs @@ -0,0 +1,27 @@ +use agent_client_protocol as acp; + +pub type Result = std::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("ACP protocol error: {0}")] + Protocol(#[from] acp::Error), + + #[error("Forge application error: {0}")] + Application(#[from] anyhow::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +impl From for acp::Error { + fn from(error: Error) -> Self { + match error { + Error::Protocol(error) => error, + Error::Application(error) => { + acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) + } + Error::Io(error) => acp::Error::into_internal_error(&error), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/mod.rs b/crates/forge_app/src/acp/mod.rs new file mode 100644 index 0000000000..d90a00fda9 --- /dev/null +++ b/crates/forge_app/src/acp/mod.rs @@ -0,0 +1,88 @@ +mod adapter; +mod conversion; +mod error; +mod prompt_handler; +mod session_handlers; +mod state_builders; + +pub(crate) use adapter::AcpAdapter; + +#[async_trait::async_trait(?Send)] +impl agent_client_protocol::Agent for AcpAdapter { + async fn initialize( + &self, + arguments: agent_client_protocol::InitializeRequest, + ) -> std::result::Result< + agent_client_protocol::InitializeResponse, + agent_client_protocol::Error, + > { + self.handle_initialize(arguments).await + } + + async fn authenticate( + &self, + arguments: agent_client_protocol::AuthenticateRequest, + ) -> std::result::Result< + agent_client_protocol::AuthenticateResponse, + agent_client_protocol::Error, + > { + self.handle_authenticate(arguments).await + } + + async fn new_session( + &self, + arguments: agent_client_protocol::NewSessionRequest, + ) -> std::result::Result< + agent_client_protocol::NewSessionResponse, + agent_client_protocol::Error, + > { + self.handle_new_session(arguments).await + } + + async fn load_session( + &self, + arguments: agent_client_protocol::LoadSessionRequest, + ) -> std::result::Result< + agent_client_protocol::LoadSessionResponse, + agent_client_protocol::Error, + > { + self.handle_load_session(arguments).await + } + + async fn prompt( + &self, + arguments: agent_client_protocol::PromptRequest, + ) -> std::result::Result< + agent_client_protocol::PromptResponse, + agent_client_protocol::Error, + > { + self.handle_prompt(arguments).await + } + + async fn cancel( + &self, + arguments: agent_client_protocol::CancelNotification, + ) -> std::result::Result<(), agent_client_protocol::Error> { + self.handle_cancel(arguments).await + } + + async fn set_session_mode( + &self, + arguments: agent_client_protocol::SetSessionModeRequest, + ) -> std::result::Result< + agent_client_protocol::SetSessionModeResponse, + agent_client_protocol::Error, + > { + self.handle_set_session_mode(arguments).await + } + + async fn set_session_model( + &self, + arguments: agent_client_protocol::SetSessionModelRequest, + ) -> std::result::Result< + agent_client_protocol::SetSessionModelResponse, + agent_client_protocol::Error, + > { + self.handle_set_session_model(arguments).await + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/prompt_handler.rs b/crates/forge_app/src/acp/prompt_handler.rs new file mode 100644 index 0000000000..08bfba5f0b --- /dev/null +++ b/crates/forge_app/src/acp/prompt_handler.rs @@ -0,0 +1,283 @@ +use std::sync::Arc; + +use agent_client_protocol as acp; +use agent_client_protocol::Client; +use forge_domain::{ + ChatRequest, ChatResponse, ChatResponseContent, Event, EventValue, InterruptionReason, +}; +use futures::StreamExt; +use tokio::sync::Notify; + +use crate::{ForgeApp, Services}; + +use super::adapter::AcpAdapter; +use super::conversion; +use super::error::{Error, Result}; + +impl AcpAdapter { + pub(super) async fn handle_prompt( + &self, + arguments: acp::PromptRequest, + ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); + let session = self.session_state(&session_key).await.map_err(acp::Error::from)?; + + let mut prompt_text_parts = Vec::new(); + let mut attachments = Vec::new(); + + for content_block in &arguments.prompt { + match content_block { + acp::ContentBlock::Text(text_content) => { + prompt_text_parts.push(text_content.text.clone()); + } + acp::ContentBlock::ResourceLink(resource_link) => { + let path = conversion::uri_to_path(&resource_link.uri); + prompt_text_parts.push(format!("@[{}]", path)); + } + acp::ContentBlock::Resource(embedded_resource) => { + match conversion::acp_resource_to_attachment(embedded_resource) { + Ok(attachment) => attachments.push(attachment), + Err(error) => { + tracing::warn!("Failed to convert embedded resource: {}", error); + } + } + } + _ => {} + } + } + + let prompt_text = prompt_text_parts.join("\n"); + let cancel_notify = Arc::new(Notify::new()); + self.set_cancel_notify(&session_key, Some(cancel_notify.clone())) + .await + .map_err(acp::Error::from)?; + + let response = self + .run_prompt_loop(&arguments.session_id, &session_key, session, prompt_text, attachments, cancel_notify) + .await; + + let _ = self.set_cancel_notify(&session_key, None).await; + response + } + + async fn run_prompt_loop( + &self, + session_id: &acp::SessionId, + session_key: &str, + session: super::adapter::SessionState, + prompt_text: String, + attachments: Vec, + cancel_notify: Arc, + ) -> std::result::Result { + let mut event = Event::new(EventValue::text(prompt_text)); + event.attachments = attachments; + + let mut chat_request = ChatRequest::new(event, session.conversation_id); + loop { + let app = ForgeApp::new(self.services.clone()); + let mut stream = app + .chat(session.agent_id.clone(), chat_request) + .await + .map_err(|error| acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error))?; + + let mut continue_after_interrupt = false; + + loop { + tokio::select! { + _ = cancel_notify.notified() => { + tracing::info!("ACP prompt cancelled for session {}", session_key); + return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); + } + response_result = stream.next() => { + match response_result { + Some(Ok(response)) => { + self.handle_chat_response(session_id, response, &mut continue_after_interrupt).await?; + } + Some(Err(error)) => { + tracing::error!("Error in chat stream: {}", error); + return Err(acp::Error::into_internal_error( + error.as_ref() as &dyn std::error::Error, + )); + } + None => { + break; + } + } + } + } + } + + if continue_after_interrupt { + chat_request = ChatRequest::new(Event::new(EventValue::text("")), session.conversation_id); + continue; + } + + return Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)); + } + } + + async fn handle_chat_response( + &self, + session_id: &acp::SessionId, + response: ChatResponse, + continue_after_interrupt: &mut bool, + ) -> std::result::Result<(), acp::Error> { + match response { + ChatResponse::TaskMessage { content } => { + self.handle_task_message(session_id, content).await?; + } + ChatResponse::TaskReasoning { content } => { + if !content.is_empty() { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(content)), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + } + ChatResponse::ToolCallStart { tool_call, .. } => { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::ToolCallUpdate( + conversion::map_tool_call_to_acp(&tool_call).into(), + ), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + ChatResponse::ToolCallEnd(tool_result) => { + let content = conversion::ToolOutputConverter::convert(&tool_result.output); + let status = if tool_result.output.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }; + let tool_call_id = tool_result + .call_id + .as_ref() + .map(|id| id.as_str().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let update = acp::ToolCallUpdate::new( + tool_call_id, + acp::ToolCallUpdateFields::new().status(status).content(content), + ); + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::ToolCallUpdate(update), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + ChatResponse::TaskComplete => {} + ChatResponse::RetryAttempt { .. } => {} + ChatResponse::Interrupt { reason } => { + let should_continue = self + .request_continue_permission(session_id, &reason) + .await + .map_err(acp::Error::from)?; + if should_continue { + *continue_after_interrupt = true; + } + } + } + + Ok(()) + } + + async fn handle_task_message( + &self, + session_id: &acp::SessionId, + content: ChatResponseContent, + ) -> std::result::Result<(), acp::Error> { + match content { + ChatResponseContent::ToolOutput(_) => {} + ChatResponseContent::Markdown { text, .. } => { + if !text.is_empty() { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(text)), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + } + ChatResponseContent::ToolInput(_) => {} + } + + Ok(()) + } + + async fn request_continue_permission( + &self, + session_id: &acp::SessionId, + reason: &InterruptionReason, + ) -> Result { + let client_conn = self.client_conn.lock().await; + let Some(conn) = client_conn.as_ref() else { + return Ok(false); + }; + + let (title, description) = format_interruption(reason); + let options = vec![ + acp::PermissionOption::new( + "continue", + "Continue Anyway", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new("stop", "Stop", acp::PermissionOptionKind::RejectOnce), + ]; + let tool_call_update = acp::ToolCallUpdate::new( + "interrupt-continue", + acp::ToolCallUpdateFields::new() + .status(acp::ToolCallStatus::Pending) + .title(title.clone()), + ); + + let mut request = acp::RequestPermissionRequest::new( + session_id.clone(), + tool_call_update, + options, + ); + let mut meta = serde_json::Map::new(); + meta.insert("title".to_string(), serde_json::json!(title)); + meta.insert("description".to_string(), serde_json::json!(description)); + request = request.meta(meta); + + let response = conn.request_permission(request).await.map_err(|error| { + Error::Application(anyhow::anyhow!("Permission request failed: {}", error)) + })?; + + match response.outcome { + acp::RequestPermissionOutcome::Selected(selection) => { + Ok(selection.option_id.0.as_ref() == "continue") + } + acp::RequestPermissionOutcome::Cancelled => Ok(false), + _ => Ok(false), + } + } +} + +fn format_interruption(reason: &InterruptionReason) -> (String, String) { + match reason { + InterruptionReason::MaxToolFailurePerTurnLimitReached { limit, errors } => { + let error_summary = errors + .iter() + .map(|(tool_name, count)| format!("{} ({})", tool_name, count)) + .collect::>() + .join(", "); + ( + format!("Tool failure limit reached ({})", limit), + format!("Forge stopped after repeated tool failures: {}", error_summary), + ) + } + InterruptionReason::MaxRequestPerTurnLimitReached { limit } => ( + format!("Request limit reached ({})", limit), + "Forge reached the maximum number of requests for this turn.".to_string(), + ), + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/session_handlers.rs b/crates/forge_app/src/acp/session_handlers.rs new file mode 100644 index 0000000000..4d6ce8b7ff --- /dev/null +++ b/crates/forge_app/src/acp/session_handlers.rs @@ -0,0 +1,225 @@ +use agent_client_protocol as acp; +use forge_domain::{AgentId, Conversation, ConversationId, ModelId}; + +use crate::{AgentRegistry, AppConfigService, ConversationService, Services}; + +use super::adapter::{AcpAdapter, SessionState}; +use super::state_builders::StateBuilders; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +impl AcpAdapter { + pub(super) async fn handle_initialize( + &self, + arguments: acp::InitializeRequest, + ) -> std::result::Result { + tracing::info!("Received initialize request from client: {:?}", arguments.client_info); + + Ok(acp::InitializeResponse::new(acp::ProtocolVersion::V1) + .agent_capabilities( + acp::AgentCapabilities::new().load_session(true).mcp_capabilities( + acp::McpCapabilities::new() + .http(true) + .sse(true), + ), + ) + .agent_info( + acp::Implementation::new("forge".to_string(), VERSION.to_string()) + .title("Forge Code".to_string()), + )) + } + + pub(super) async fn handle_authenticate( + &self, + _arguments: acp::AuthenticateRequest, + ) -> std::result::Result { + Ok(acp::AuthenticateResponse::default()) + } + + pub(super) async fn handle_new_session( + &self, + arguments: acp::NewSessionRequest, + ) -> std::result::Result { + if !arguments.mcp_servers.is_empty() { + StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) + .await + .map_err(acp::Error::from)?; + } + + let active_agent_id = self + .services + .agent_registry() + .get_active_agent_id() + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .unwrap_or_default(); + + let conversation = Conversation::generate(); + let conversation_id = conversation.id; + self.services + .conversation_service() + .upsert_conversation(conversation) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + + let session_id = acp::SessionId::new(conversation_id.into_string()); + let session_key = session_id.0.as_ref().to_string(); + self.store_session( + session_key, + SessionState { + conversation_id, + agent_id: active_agent_id.clone(), + cancel_notify: None, + }, + ) + .await; + + let agent = self + .services + .agent_registry() + .get_agent(&active_agent_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .ok_or_else(|| { + acp::Error::into_internal_error(&*anyhow::anyhow!( + "Agent '{}' not found", + active_agent_id + )) + })?; + + let mode_state = StateBuilders::build_session_mode_state( + self.services.as_ref(), + &active_agent_id, + ) + .await + .map_err(acp::Error::from)?; + let model_state = StateBuilders::build_session_model_state(&self.services, &agent) + .await + .map_err(acp::Error::from)?; + + Ok(acp::NewSessionResponse::new(session_id) + .modes(mode_state) + .models(model_state)) + } + + pub(super) async fn handle_load_session( + &self, + arguments: acp::LoadSessionRequest, + ) -> std::result::Result { + if !arguments.mcp_servers.is_empty() { + StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) + .await + .map_err(acp::Error::from)?; + } + + let session_key = arguments.session_id.0.as_ref().to_string(); + let conversation_id = ConversationId::parse(&session_key) + .map_err(|error| acp::Error::into_internal_error(&error))?; + + let conversation = self + .services + .conversation_service() + .find_conversation(&conversation_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + if conversation.is_none() { + return Err(acp::Error::invalid_params()); + } + + let active_agent_id = self + .services + .agent_registry() + .get_active_agent_id() + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .unwrap_or_default(); + let state = self + .ensure_session(&session_key, conversation_id, active_agent_id.clone()) + .await; + + let agent = self + .services + .agent_registry() + .get_agent(&state.agent_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .ok_or_else(|| acp::Error::invalid_params())?; + + let mode_state = StateBuilders::build_session_mode_state( + self.services.as_ref(), + &state.agent_id, + ) + .await + .map_err(acp::Error::from)?; + let model_state = StateBuilders::build_session_model_state(&self.services, &agent) + .await + .map_err(acp::Error::from)?; + + Ok(acp::LoadSessionResponse::new() + .modes(mode_state) + .models(model_state)) + } + + pub(super) async fn handle_cancel( + &self, + arguments: acp::CancelNotification, + ) -> std::result::Result<(), acp::Error> { + let session_key = arguments.session_id.0.as_ref().to_string(); + let cancelled = self.cancel_session(&session_key).await; + if !cancelled { + tracing::warn!("No active ACP prompt to cancel for session {}", session_key); + } + Ok(()) + } + + pub(super) async fn handle_set_session_mode( + &self, + arguments: acp::SetSessionModeRequest, + ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); + let mode_id = arguments.mode_id.0.as_ref(); + let agent_id = AgentId::new(mode_id); + + self.update_session_agent(&session_key, agent_id.clone()) + .await + .map_err(acp::Error::from)?; + + let notification = acp::SessionNotification::new( + arguments.session_id, + acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate::new( + acp::SessionModeId::new(mode_id.to_string()), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + + Ok(acp::SetSessionModeResponse::new()) + } + + pub(super) async fn handle_set_session_model( + &self, + arguments: acp::SetSessionModelRequest, + ) -> std::result::Result { + let model_id = ModelId::new(arguments.model_id.0.to_string()); + self.services + .set_default_model(model_id.clone()) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + let _ = self.services.reload_agents().await; + + let notification = acp::SessionNotification::new( + arguments.session_id, + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(format!( + "Model changed to: {}\n\n", + model_id + ))), + )), + ); + if let Err(error) = self.send_notification(notification) { + tracing::warn!("Failed to send model change notification: {}", error); + } + + Ok(acp::SetSessionModelResponse::default()) + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/state_builders.rs b/crates/forge_app/src/acp/state_builders.rs new file mode 100644 index 0000000000..1301ae9323 --- /dev/null +++ b/crates/forge_app/src/acp/state_builders.rs @@ -0,0 +1,180 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use agent_client_protocol as acp; +use forge_domain::{Agent, AgentId, McpHttpServer, McpServerConfig, Scope, ServerName}; + +use crate::{ + AgentProviderResolver, AgentRegistry, McpConfigManager, McpService, ProviderAuthService, + ProviderService, Services, +}; + +use super::conversion; +use super::error::{Error, Result}; + +pub(super) struct StateBuilders; + +impl StateBuilders { + pub(super) async fn build_session_mode_state( + services: &S, + current_agent_id: &AgentId, + ) -> Result { + let agents = services + .agent_registry() + .get_agents() + .await + .map_err(Error::Application)?; + + Ok(conversion::build_session_mode_state( + &agents, + current_agent_id, + )) + } + + pub(super) async fn build_session_model_state( + services: &Arc, + current_agent: &Agent, + ) -> Result { + let agent_provider_resolver = AgentProviderResolver::new(services.clone()); + let provider = agent_provider_resolver + .get_provider(Some(current_agent.id.clone())) + .await + .map_err(Error::Application)?; + let provider = services + .provider_auth_service() + .refresh_provider_credential(provider) + .await + .map_err(Error::Application)?; + + let mut models = services + .provider_service() + .models(provider) + .await + .map_err(Error::Application)?; + models.sort_by(|left, right| left.name.cmp(&right.name)); + + let available_models = models + .iter() + .map(|model| { + let mut model_info = acp::ModelInfo::new( + model.id.to_string(), + model.name.clone().unwrap_or_else(|| model.id.to_string()), + ) + .description(model.description.clone()); + + let mut meta = serde_json::Map::new(); + if let Some(context_length) = model.context_length { + meta.insert( + "contextLength".to_string(), + serde_json::json!(context_length), + ); + } + if let Some(tools_supported) = model.tools_supported { + meta.insert( + "toolsSupported".to_string(), + serde_json::json!(tools_supported), + ); + } + if let Some(supports_reasoning) = model.supports_reasoning { + meta.insert( + "supportsReasoning".to_string(), + serde_json::json!(supports_reasoning), + ); + } + if !model.input_modalities.is_empty() { + let modalities = model + .input_modalities + .iter() + .map(|modality| format!("{:?}", modality).to_lowercase()) + .collect::>(); + meta.insert("inputModalities".to_string(), serde_json::json!(modalities)); + } + if !meta.is_empty() { + model_info = model_info.meta(meta); + } + + model_info + }) + .collect(); + + Ok( + acp::SessionModelState::new(current_agent.model.to_string(), available_models).meta({ + let mut meta = serde_json::Map::new(); + meta.insert("searchable".to_string(), serde_json::json!(true)); + meta.insert("searchThreshold".to_string(), serde_json::json!(10)); + meta.insert("filterable".to_string(), serde_json::json!(true)); + meta.insert("groupBy".to_string(), serde_json::json!("provider")); + meta + }), + ) + } + + pub(super) async fn load_mcp_servers( + services: &S, + mcp_servers: &[acp::McpServer], + ) -> Result<()> { + let mut config = services + .mcp_config_manager() + .read_mcp_config(Some(&Scope::Local)) + .await + .map_err(Error::Application)?; + + for server in mcp_servers { + let (name, server_config) = Self::acp_to_mcp_server_config(server)?; + config.mcp_servers.insert(name, server_config); + } + + services + .mcp_config_manager() + .write_mcp_config(&config, &Scope::Local) + .await + .map_err(Error::Application)?; + services.mcp_service().reload_mcp().await.map_err(Error::Application)?; + Ok(()) + } + + fn acp_to_mcp_server_config(server: &acp::McpServer) -> Result<(ServerName, McpServerConfig)> { + match server { + acp::McpServer::Stdio(stdio) => { + let env = stdio + .env + .iter() + .map(|entry| (entry.name.clone(), entry.value.clone())) + .collect::>(); + Ok(( + ServerName::from(stdio.name.clone()), + McpServerConfig::new_stdio(stdio.command.to_string_lossy().to_string(), stdio.args.clone(), Some(env)), + )) + } + acp::McpServer::Http(http) => Ok(( + ServerName::from(http.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: http.url.clone(), + headers: http + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )), + acp::McpServer::Sse(sse) => Ok(( + ServerName::from(sse.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: sse.url.clone(), + headers: sse + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )), + _ => Err(Error::Application(anyhow::anyhow!( + "Unsupported MCP server type" + ))), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp_app.rs b/crates/forge_app/src/acp_app.rs new file mode 100644 index 0000000000..c86c99bb02 --- /dev/null +++ b/crates/forge_app/src/acp_app.rs @@ -0,0 +1,89 @@ +use std::sync::Arc; + +use anyhow::Result; + +use crate::Services; + +/// ACP (Agent Communication Protocol) application orchestrator. +pub struct AcpApp { + services: Arc, +} + +impl AcpApp { + /// Creates a new ACP application orchestrator. + pub fn new(services: Arc) -> Self { + Self { services } + } + + /// Starts the ACP server over stdio transport. + pub async fn start_stdio(&self) -> Result<()> { + use agent_client_protocol as acp; + use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + + let services = self.services.clone(); + let handle = tokio::task::spawn_blocking(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create Tokio runtime"); + + rt.block_on(async move { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let adapter = Arc::new(crate::acp::AcpAdapter::new(services, tx)); + + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async move { + let outgoing = tokio::io::stdout().compat_write(); + let incoming = tokio::io::stdin().compat(); + + let (conn, handle_io) = acp::AgentSideConnection::new( + adapter.clone(), + outgoing, + incoming, + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + let conn = Arc::new(conn); + adapter.set_client_connection(conn.clone()).await; + + let conn_for_notifications = conn.clone(); + let notification_task = tokio::task::spawn_local(async move { + let mut rx = rx; + while let Some(session_notification) = rx.recv().await { + use agent_client_protocol::Client; + + if let Err(error) = conn_for_notifications + .session_notification(session_notification) + .await + { + tracing::error!( + "Failed to send session notification: {}", + error + ); + break; + } + } + }); + + let io_result = handle_io.await; + notification_task.abort(); + + io_result.map_err(|error| anyhow::anyhow!("ACP transport error: {}", error)) + }) + .await + }) + }); + + match handle.await { + Ok(result) => result, + Err(error) if error.is_cancelled() => { + tracing::info!("ACP server task was cancelled"); + Ok(()) + } + Err(error) => Err(anyhow::anyhow!("ACP server task panicked: {}", error)), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/lib.rs b/crates/forge_app/src/lib.rs index 1b3295498c..21e419ea48 100644 --- a/crates/forge_app/src/lib.rs +++ b/crates/forge_app/src/lib.rs @@ -1,3 +1,5 @@ +mod acp; +mod acp_app; mod agent; mod agent_executor; mod agent_provider_resolver; @@ -38,6 +40,7 @@ pub mod utils; mod walker; mod workspace_status; +pub use acp_app::*; pub use agent::*; pub use agent_provider_resolver::*; pub use app::*; diff --git a/crates/forge_main/src/acp_runner.rs b/crates/forge_main/src/acp_runner.rs new file mode 100644 index 0000000000..e04ea42e31 --- /dev/null +++ b/crates/forge_main/src/acp_runner.rs @@ -0,0 +1,57 @@ +use std::future::Future; + +use anyhow::Result; +use forge_api::API; + +pub trait MachineStdioApi { + fn acp_start_stdio(&self) -> impl Future> + Send; +} + +impl MachineStdioApi for T { + fn acp_start_stdio(&self) -> impl Future> + Send { + API::acp_start_stdio(self) + } +} + +pub async fn run_machine_stdio_server(api: &A) -> Result<()> { + api.acp_start_stdio().await +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; + + use anyhow::Result; + + use super::{MachineStdioApi, run_machine_stdio_server}; + + struct MockApi { + called: Arc, + } + + impl MockApi { + fn new(called: Arc) -> Self { + Self { called } + } + } + + impl MachineStdioApi for MockApi { + fn acp_start_stdio(&self) -> impl std::future::Future> + Send { + self.called.store(true, Ordering::SeqCst); + async { Ok(()) } + } + } + + #[tokio::test] + async fn test_run_machine_stdio_server_delegates_to_api_transport() -> Result<()> { + let called = Arc::new(AtomicBool::new(false)); + let fixture = MockApi::new(called.clone()); + + run_machine_stdio_server(&fixture).await?; + + let actual = called.load(Ordering::SeqCst); + let expected = true; + assert_eq!(actual, expected); + Ok(()) + } +} \ No newline at end of file diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 4cb6ff66f6..6d95e12874 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -82,6 +82,9 @@ pub enum TopLevelCommand { /// Manage agents. Agent(AgentCommandGroup), + /// Run machine-oriented commands. + Machine(MachineCommandGroup), + /// Generate shell extension scripts. #[command(subcommand, alias = "extension")] Zsh(ZshCommandGroup), @@ -203,6 +206,19 @@ pub enum AgentCommand { List, } +/// Command group for machine-oriented interfaces. +#[derive(Parser, Debug, Clone)] +pub struct MachineCommandGroup { + #[command(subcommand)] + pub command: MachineCommand, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum MachineCommand { + /// Run the machine interface over stdio. + Stdio, +} + /// Command group for workspace management. #[derive(Parser, Debug, Clone)] pub struct WorkspaceCommandGroup { @@ -1864,4 +1880,25 @@ mod tests { }; assert!(!actual); } + + #[test] + fn test_machine_stdio_command() { + let fixture = Cli::parse_from(["forge", "machine", "stdio"]); + let actual = matches!( + fixture.subcommands, + Some(TopLevelCommand::Machine(MachineCommandGroup { + command: MachineCommand::Stdio, + })) + ); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_machine_stdio_is_not_interactive() { + let fixture = Cli::parse_from(["forge", "machine", "stdio"]); + let actual = fixture.is_interactive(); + let expected = false; + assert_eq!(actual, expected); + } } diff --git a/crates/forge_main/src/lib.rs b/crates/forge_main/src/lib.rs index 1fc22a116d..1690693660 100644 --- a/crates/forge_main/src/lib.rs +++ b/crates/forge_main/src/lib.rs @@ -1,4 +1,5 @@ pub mod banner; +mod acp_runner; mod cli; mod completer; mod conversation_selector; diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 3d6b946bac..c95d2daf3f 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -388,6 +388,14 @@ impl A + Send + Sync> UI { } return Ok(()); } + TopLevelCommand::Machine(machine_group) => { + match machine_group.command { + crate::cli::MachineCommand::Stdio => { + crate::acp_runner::run_machine_stdio_server(self.api.as_ref()).await?; + return Ok(()); + } + } + } TopLevelCommand::List(list_group) => { let porcelain = list_group.porcelain; match list_group.command { From 8fe513fabf91987d440a805b6e1d8a73ef04f230 Mon Sep 17 00:00:00 2001 From: flazouh Date: Mon, 6 Apr 2026 01:01:07 +0300 Subject: [PATCH 2/4] fix(acp): harden machine stdio transport - Replace unbounded notification channel with bounded (1024) to apply backpressure when the client stalls - Add per-session model override to prevent concurrent sessions from interfering with each other - Replace From impl with explicit into_acp_error() per project guidelines - Extract classify_mcp_tool() and convert to free functions, removing the unnecessary ToolOutputConverter struct - Validate MCP server names (length, charset) to prevent injection - Add MAX_BLOB_SIZE (50 MB) guard on base64-decoded resources - Add I/O timeout (5 min) and graceful shutdown drain (5 s) to prevent indefinite hangs - Track cancellation via AtomicBool across loop iterations - Log warnings instead of silently ignoring reload/config errors - Add tests for tool kind mapping, file extraction, and edge cases Co-Authored-By: ForgeCode --- crates/forge_app/src/acp/adapter.rs | 47 +++- crates/forge_app/src/acp/conversion.rs | 246 ++++++++++++------- crates/forge_app/src/acp/error.rs | 20 +- crates/forge_app/src/acp/mod.rs | 2 +- crates/forge_app/src/acp/prompt_handler.rs | 42 +++- crates/forge_app/src/acp/session_handlers.rs | 45 +++- crates/forge_app/src/acp/state_builders.rs | 117 ++++++--- crates/forge_app/src/acp_app.rs | 43 +++- crates/forge_main/src/acp_runner.rs | 4 +- crates/forge_main/src/cli.rs | 1 + 10 files changed, 396 insertions(+), 171 deletions(-) diff --git a/crates/forge_app/src/acp/adapter.rs b/crates/forge_app/src/acp/adapter.rs index d7c1ebc73d..9ac30b45bf 100644 --- a/crates/forge_app/src/acp/adapter.rs +++ b/crates/forge_app/src/acp/adapter.rs @@ -2,38 +2,46 @@ use std::collections::HashMap; use std::sync::Arc; use agent_client_protocol as acp; -use forge_domain::{AgentId, ConversationId}; +use forge_domain::{AgentId, ConversationId, ModelId}; use tokio::sync::{Mutex, Notify, mpsc}; use crate::Services; use super::error::{Error, Result}; +/// Maximum number of buffered session notifications before backpressure. +const NOTIFICATION_CHANNEL_CAPACITY: usize = 1024; + #[derive(Clone)] pub(super) struct SessionState { pub conversation_id: ConversationId, pub agent_id: AgentId, + /// Session-scoped model override. When set, prompts use this model + /// instead of the global default. + pub model_id: Option, pub cancel_notify: Option>, } pub(crate) struct AcpAdapter { pub(super) services: Arc, - pub(super) session_update_tx: mpsc::UnboundedSender, + pub(super) session_update_tx: mpsc::Sender, pub(super) client_conn: Arc>>>, sessions: Arc>>, } impl AcpAdapter { + /// Creates a new ACP adapter and returns the notification receiver. pub(crate) fn new( services: Arc, - session_update_tx: mpsc::UnboundedSender, - ) -> Self { - Self { + ) -> (Self, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(NOTIFICATION_CHANNEL_CAPACITY); + let adapter = Self { services, - session_update_tx, + session_update_tx: tx, client_conn: Arc::new(Mutex::new(None)), sessions: Arc::new(Mutex::new(HashMap::new())), - } + }; + (adapter, rx) } pub(crate) async fn set_client_connection(&self, conn: Arc) { @@ -44,6 +52,13 @@ impl AcpAdapter { self.sessions.lock().await.insert(session_id, state); } + /// Removes a session from the adapter. Currently unused but available + /// for future session lifecycle management (TTL, explicit close). + #[allow(dead_code)] + pub(super) async fn remove_session(&self, session_id: &str) { + self.sessions.lock().await.remove(session_id); + } + pub(super) async fn session_state(&self, session_id: &str) -> Result { self.sessions .lock() @@ -66,6 +81,19 @@ impl AcpAdapter { Ok(()) } + pub(super) async fn update_session_model( + &self, + session_id: &str, + model_id: ModelId, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.model_id = Some(model_id); + Ok(()) + } + pub(super) async fn set_cancel_notify( &self, session_id: &str, @@ -107,6 +135,7 @@ impl AcpAdapter { .or_insert_with(|| SessionState { conversation_id, agent_id, + model_id: None, cancel_notify: None, }) .clone() @@ -114,7 +143,7 @@ impl AcpAdapter { pub(super) fn send_notification(&self, notification: acp::SessionNotification) -> Result<()> { self.session_update_tx - .send(notification) + .try_send(notification) .map_err(|_| Error::Application(anyhow::anyhow!("Failed to send notification"))) } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/conversion.rs b/crates/forge_app/src/acp/conversion.rs index 31a9021ce0..640fe42868 100644 --- a/crates/forge_app/src/acp/conversion.rs +++ b/crates/forge_app/src/acp/conversion.rs @@ -8,6 +8,16 @@ use forge_domain::{ use super::error::{Error, Result}; +/// Maximum size in bytes for base64-encoded blob resources. +/// Protects against OOM from oversized client payloads. +const MAX_BLOB_SIZE: usize = 50 * 1024 * 1024; // 50 MB + +/// Maps a Forge tool name to an ACP ToolKind. +/// +/// Native Forge tools are classified by exact match. MCP tools (prefixed +/// with `mcp_`) use best-effort keyword heuristics and default to `Other` +/// when the name is ambiguous. The heuristic is order-dependent: the first +/// matching keyword category wins. pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { match tool_name.as_str() { "read" => acp::ToolKind::Read, @@ -17,62 +27,51 @@ pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { "shell" => acp::ToolKind::Execute, "fetch" => acp::ToolKind::Fetch, "sage" => acp::ToolKind::Think, - _ => { - let name = tool_name.as_str(); - if name.starts_with("mcp_") { - if name.contains("read") - || name.contains("get") - || name.contains("fetch") - || name.contains("list") - || name.contains("show") - || name.contains("view") - || name.contains("load") - { - acp::ToolKind::Read - } else if name.contains("search") - || name.contains("query") - || name.contains("find") - || name.contains("filter") - || name.contains("lookup") - { - acp::ToolKind::Search - } else if name.contains("write") - || name.contains("update") - || name.contains("create") - || name.contains("set") - || name.contains("add") - || name.contains("insert") - || name.contains("push") - || name.contains("merge") - || name.contains("fork") - || name.contains("comment") - || name.contains("assign") - || name.contains("request") - { - acp::ToolKind::Edit - } else if name.contains("delete") - || name.contains("remove") - || name.contains("drop") - || name.contains("clear") - || name.contains("close") - || name.contains("cancel") - { - acp::ToolKind::Delete - } else if name.contains("execute") - || name.contains("run") - || name.contains("start") - || name.contains("invoke") - || name.contains("call") - { - acp::ToolKind::Execute - } else { - acp::ToolKind::Other - } - } else { - acp::ToolKind::Other - } + _ => classify_mcp_tool(tool_name.as_str()), + } +} + +/// Best-effort classification for MCP tools by keyword heuristic. +/// +/// Falls back to `Other` for non-MCP tools or when no keyword matches. +/// The match order matters: a tool named `mcp_get_search_results` would +/// classify as `Read` (matches "get" before "search"). +fn classify_mcp_tool(name: &str) -> acp::ToolKind { + if !name.starts_with("mcp_") { + return acp::ToolKind::Other; + } + + // Strip the "mcp__" prefix to get the action portion. + // E.g. "mcp_github_list_issues" → check against "list_issues". + let action = name + .strip_prefix("mcp_") + .and_then(|rest| rest.split_once('_').map(|(_, action)| action)) + .unwrap_or(name); + + const READ_KEYWORDS: &[&str] = &["read", "get", "fetch", "list", "show", "view", "load"]; + const SEARCH_KEYWORDS: &[&str] = &["search", "query", "find", "filter", "lookup"]; + const EDIT_KEYWORDS: &[&str] = &[ + "write", "update", "create", "set", "add", "insert", "push", "merge", + "fork", "comment", "assign", "request", + ]; + const DELETE_KEYWORDS: &[&str] = &["delete", "remove", "drop", "clear", "close", "cancel"]; + const EXECUTE_KEYWORDS: &[&str] = &["execute", "run", "start", "invoke", "call"]; + + let checks: &[(&[&str], acp::ToolKind)] = &[ + (READ_KEYWORDS, acp::ToolKind::Read), + (SEARCH_KEYWORDS, acp::ToolKind::Search), + (EDIT_KEYWORDS, acp::ToolKind::Edit), + (DELETE_KEYWORDS, acp::ToolKind::Delete), + (EXECUTE_KEYWORDS, acp::ToolKind::Execute), + ]; + + for (keywords, kind) in checks { + if keywords.iter().any(|kw| action.contains(kw)) { + return kind.clone(); } } + + acp::ToolKind::Other } pub(crate) fn extract_file_locations( @@ -112,44 +111,33 @@ pub(crate) fn map_tool_call_to_acp(tool_call: &ToolCallFull) -> acp::ToolCall { ) } -pub(crate) struct ToolOutputConverter { - _private: (), +/// Converts a ToolOutput into ACP content blocks. +pub(crate) fn convert_tool_output(output: &ToolOutput) -> Vec { + output + .values + .iter() + .filter_map(convert_tool_value) + .collect() } -impl ToolOutputConverter { - pub(crate) fn new(output: &ToolOutput) -> Self { - let _ = output; - Self { _private: () } - } - - pub(crate) fn convert(output: &ToolOutput) -> Vec { - let converter = Self::new(output); - output - .values - .iter() - .filter_map(|value| converter.convert_value(value)) - .collect() - } - - fn convert_value(&self, value: &ToolValue) -> Option { - match value { - ToolValue::Text(text) => self.convert_text(text), - ToolValue::AI { value, .. } => self.convert_text(value), - ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( - acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), - ))), - ToolValue::Empty => None, - } +fn convert_tool_value(value: &ToolValue) -> Option { + match value { + ToolValue::Text(text) => convert_text(text), + ToolValue::AI { value, .. } => convert_text(value), + ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), + ))), + ToolValue::Empty => None, } +} - fn convert_text(&self, text: &str) -> Option { - if text.is_empty() { - None - } else { - Some(acp::ToolCallContent::Content(acp::Content::new( - acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), - ))) - } +fn convert_text(text: &str) -> Option { + if text.is_empty() { + None + } else { + Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), + ))) } } @@ -159,6 +147,12 @@ pub(crate) fn acp_resource_to_attachment(resource: &acp::EmbeddedResource) -> Re (text_resource.text.clone(), text_resource.uri.clone()) } acp::EmbeddedResourceResource::BlobResourceContents(blob_resource) => { + if blob_resource.blob.len() > MAX_BLOB_SIZE { + return Err(Error::Application(anyhow::anyhow!( + "Blob resource exceeds maximum size of {} bytes", + MAX_BLOB_SIZE + ))); + } let decoded = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &blob_resource.blob, @@ -237,11 +231,19 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_uri_to_path_strips_file_prefix() { + let fixture = "file:///home/user/file.txt"; + let actual = uri_to_path(fixture); + let expected = "/home/user/file.txt".to_string(); + assert_eq!(actual, expected); + } + #[test] fn test_markdown_sent_to_acp_not_xml() { let fixture = ToolOutput::text("## File: test.txt\n\nContent here"); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -259,7 +261,7 @@ mod tests { fn test_ai_output_sent_to_acp_as_text() { let fixture = ToolOutput::ai(ConversationId::generate(), "Agent result"); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -278,7 +280,7 @@ mod tests { let image = Image::new_bytes(vec![1, 2, 3, 4], "image/png".to_string()); let fixture = ToolOutput::image(image); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -287,4 +289,64 @@ mod tests { panic!("Expected content"); } } -} \ No newline at end of file + + #[test] + fn test_empty_output_produces_no_content() { + let fixture = ToolOutput::text(""); + let actual = convert_tool_output(&fixture); + let expected: Vec = vec![]; + assert_eq!(actual.len(), expected.len()); + } + + #[test] + fn test_map_tool_kind_native_tools() { + let fixture = ToolName::new("read"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Read)); + } + + #[test] + fn test_map_tool_kind_mcp_read() { + let fixture = ToolName::new("mcp_github_list_issues"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Read)); + } + + #[test] + fn test_map_tool_kind_mcp_search() { + let fixture = ToolName::new("mcp_db_search_records"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Search)); + } + + #[test] + fn test_map_tool_kind_unknown_defaults_to_other() { + let fixture = ToolName::new("mcp_custom_foobar"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Other)); + } + + #[test] + fn test_map_tool_kind_non_mcp_unknown() { + let fixture = ToolName::new("custom_tool"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Other)); + } + + #[test] + fn test_extract_file_locations_read_tool() { + let fixture_name = ToolName::new("read"); + let fixture_args = serde_json::json!({"file_path": "/tmp/test.rs"}); + let actual = extract_file_locations(&fixture_name, &fixture_args); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_extract_file_locations_unknown_tool() { + let fixture_name = ToolName::new("shell"); + let fixture_args = serde_json::json!({"command": "ls"}); + let actual = extract_file_locations(&fixture_name, &fixture_args); + let expected: Vec = vec![]; + assert_eq!(actual.len(), expected.len()); + } +} diff --git a/crates/forge_app/src/acp/error.rs b/crates/forge_app/src/acp/error.rs index 1dbf4a0696..9a452b4a93 100644 --- a/crates/forge_app/src/acp/error.rs +++ b/crates/forge_app/src/acp/error.rs @@ -14,14 +14,16 @@ pub enum Error { Io(#[from] std::io::Error), } -impl From for acp::Error { - fn from(error: Error) -> Self { - match error { - Error::Protocol(error) => error, - Error::Application(error) => { - acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) - } - Error::Io(error) => acp::Error::into_internal_error(&error), +/// Converts a domain Error into an acp::Error. +/// +/// AGENTS.md forbids blanket `From` impls for domain error conversion. +/// Call this explicitly at each `.map_err()` site instead. +pub fn into_acp_error(error: Error) -> acp::Error { + match error { + Error::Protocol(error) => error, + Error::Application(error) => { + acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) } + Error::Io(error) => acp::Error::into_internal_error(&error), } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/mod.rs b/crates/forge_app/src/acp/mod.rs index d90a00fda9..716d103c20 100644 --- a/crates/forge_app/src/acp/mod.rs +++ b/crates/forge_app/src/acp/mod.rs @@ -85,4 +85,4 @@ impl agent_client_protocol::Agent for AcpAdapter { > { self.handle_set_session_model(arguments).await } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/prompt_handler.rs b/crates/forge_app/src/acp/prompt_handler.rs index 08bfba5f0b..d057b76c54 100644 --- a/crates/forge_app/src/acp/prompt_handler.rs +++ b/crates/forge_app/src/acp/prompt_handler.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use agent_client_protocol as acp; use agent_client_protocol::Client; @@ -12,7 +13,7 @@ use crate::{ForgeApp, Services}; use super::adapter::AcpAdapter; use super::conversion; -use super::error::{Error, Result}; +use super::error::{self, Error, Result}; impl AcpAdapter { pub(super) async fn handle_prompt( @@ -20,7 +21,7 @@ impl AcpAdapter { arguments: acp::PromptRequest, ) -> std::result::Result { let session_key = arguments.session_id.0.as_ref().to_string(); - let session = self.session_state(&session_key).await.map_err(acp::Error::from)?; + let session = self.session_state(&session_key).await.map_err(error::into_acp_error)?; let mut prompt_text_parts = Vec::new(); let mut attachments = Vec::new(); @@ -48,12 +49,21 @@ impl AcpAdapter { let prompt_text = prompt_text_parts.join("\n"); let cancel_notify = Arc::new(Notify::new()); + let cancelled = Arc::new(AtomicBool::new(false)); self.set_cancel_notify(&session_key, Some(cancel_notify.clone())) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let response = self - .run_prompt_loop(&arguments.session_id, &session_key, session, prompt_text, attachments, cancel_notify) + .run_prompt_loop( + &arguments.session_id, + &session_key, + session, + prompt_text, + attachments, + cancel_notify, + cancelled, + ) .await; let _ = self.set_cancel_notify(&session_key, None).await; @@ -68,12 +78,21 @@ impl AcpAdapter { prompt_text: String, attachments: Vec, cancel_notify: Arc, + cancelled: Arc, ) -> std::result::Result { let mut event = Event::new(EventValue::text(prompt_text)); event.attachments = attachments; let mut chat_request = ChatRequest::new(event, session.conversation_id); loop { + // Check if cancellation was requested before starting a new + // chat round (handles the case where cancel arrives between + // loop iterations). + if cancelled.load(Ordering::SeqCst) { + tracing::info!("ACP prompt cancelled for session {}", session_key); + return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); + } + let app = ForgeApp::new(self.services.clone()); let mut stream = app .chat(session.agent_id.clone(), chat_request) @@ -85,6 +104,7 @@ impl AcpAdapter { loop { tokio::select! { _ = cancel_notify.notified() => { + cancelled.store(true, Ordering::SeqCst); tracing::info!("ACP prompt cancelled for session {}", session_key); return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); } @@ -135,7 +155,7 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } } ChatResponse::ToolCallStart { tool_call, .. } => { @@ -146,10 +166,10 @@ impl AcpAdapter { ), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } ChatResponse::ToolCallEnd(tool_result) => { - let content = conversion::ToolOutputConverter::convert(&tool_result.output); + let content = conversion::convert_tool_output(&tool_result.output); let status = if tool_result.output.is_error { acp::ToolCallStatus::Failed } else { @@ -169,7 +189,7 @@ impl AcpAdapter { acp::SessionUpdate::ToolCallUpdate(update), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } ChatResponse::TaskComplete => {} ChatResponse::RetryAttempt { .. } => {} @@ -177,7 +197,7 @@ impl AcpAdapter { let should_continue = self .request_continue_permission(session_id, &reason) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; if should_continue { *continue_after_interrupt = true; } @@ -203,7 +223,7 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } } ChatResponseContent::ToolInput(_) => {} @@ -280,4 +300,4 @@ fn format_interruption(reason: &InterruptionReason) -> (String, String) { "Forge reached the maximum number of requests for this turn.".to_string(), ), } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/session_handlers.rs b/crates/forge_app/src/acp/session_handlers.rs index 4d6ce8b7ff..4b3c7a576e 100644 --- a/crates/forge_app/src/acp/session_handlers.rs +++ b/crates/forge_app/src/acp/session_handlers.rs @@ -4,6 +4,7 @@ use forge_domain::{AgentId, Conversation, ConversationId, ModelId}; use crate::{AgentRegistry, AppConfigService, ConversationService, Services}; use super::adapter::{AcpAdapter, SessionState}; +use super::error; use super::state_builders::StateBuilders; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -29,10 +30,18 @@ impl AcpAdapter { )) } + /// Handles ACP authentication. + /// + /// This is intentionally a no-op. The stdio transport inherits OS-level + /// process isolation: only the parent process (e.g. Acepe) that spawned + /// `forge machine stdio` can read/write the stdin/stdout pipes. No + /// network listener is opened, so no additional authentication is + /// required. See `AcpApp::start_stdio` for the full trust model. pub(super) async fn handle_authenticate( &self, _arguments: acp::AuthenticateRequest, ) -> std::result::Result { + tracing::debug!("ACP authenticate: no-op (stdio transport uses OS process isolation)"); Ok(acp::AuthenticateResponse::default()) } @@ -43,7 +52,7 @@ impl AcpAdapter { if !arguments.mcp_servers.is_empty() { StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } let active_agent_id = self @@ -69,6 +78,7 @@ impl AcpAdapter { SessionState { conversation_id, agent_id: active_agent_id.clone(), + model_id: None, cancel_notify: None, }, ) @@ -92,10 +102,10 @@ impl AcpAdapter { &active_agent_id, ) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let model_state = StateBuilders::build_session_model_state(&self.services, &agent) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::NewSessionResponse::new(session_id) .modes(mode_state) @@ -109,7 +119,7 @@ impl AcpAdapter { if !arguments.mcp_servers.is_empty() { StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } let session_key = arguments.session_id.0.as_ref().to_string(); @@ -150,10 +160,10 @@ impl AcpAdapter { &state.agent_id, ) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let model_state = StateBuilders::build_session_model_state(&self.services, &agent) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::LoadSessionResponse::new() .modes(mode_state) @@ -182,7 +192,7 @@ impl AcpAdapter { self.update_session_agent(&session_key, agent_id.clone()) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let notification = acp::SessionNotification::new( arguments.session_id, @@ -191,21 +201,36 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::SetSessionModeResponse::new()) } + /// Handles session model changes. + /// + /// The model preference is stored per-session so that concurrent ACP + /// clients do not interfere with each other. The global default model + /// is also updated for backward compatibility with non-ACP code paths. pub(super) async fn handle_set_session_model( &self, arguments: acp::SetSessionModelRequest, ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); let model_id = ModelId::new(arguments.model_id.0.to_string()); + + // Store per-session model preference. + self.update_session_model(&session_key, model_id.clone()) + .await + .map_err(error::into_acp_error)?; + + // Also update the global default for backward compatibility. self.services .set_default_model(model_id.clone()) .await .map_err(|error| acp::Error::into_internal_error(&*error))?; - let _ = self.services.reload_agents().await; + if let Err(error) = self.services.reload_agents().await { + tracing::warn!("Failed to reload agents after model change: {}", error); + } let notification = acp::SessionNotification::new( arguments.session_id, @@ -222,4 +247,4 @@ impl AcpAdapter { Ok(acp::SetSessionModelResponse::default()) } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/state_builders.rs b/crates/forge_app/src/acp/state_builders.rs index 1301ae9323..e1d2770b67 100644 --- a/crates/forge_app/src/acp/state_builders.rs +++ b/crates/forge_app/src/acp/state_builders.rs @@ -12,6 +12,9 @@ use crate::{ use super::conversion; use super::error::{Error, Result}; +/// Maximum allowed length for an MCP server name (prevents injection). +const MAX_SERVER_NAME_LEN: usize = 128; + pub(super) struct StateBuilders; impl StateBuilders { @@ -109,6 +112,15 @@ impl StateBuilders { ) } + /// Loads MCP server configurations provided by the ACP client. + /// + /// # Trust model + /// + /// The stdio transport inherits OS-level process isolation, so the + /// client is the parent process (Acepe). Server names are validated + /// to prevent injection. The configs are written to the local scope + /// only and do not persist across Forge restarts unless the caller + /// explicitly saves them. pub(super) async fn load_mcp_servers( services: &S, mcp_servers: &[acp::McpServer], @@ -119,10 +131,23 @@ impl StateBuilders { .await .map_err(Error::Application)?; - for server in mcp_servers { - let (name, server_config) = Self::acp_to_mcp_server_config(server)?; - config.mcp_servers.insert(name, server_config); - } + let server_names: Vec = mcp_servers + .iter() + .filter_map(|s| { + match Self::acp_to_mcp_server_config(s) { + Ok((name, server_config)) => { + config.mcp_servers.insert(name.clone(), server_config); + Some(name.to_string()) + } + Err(error) => { + tracing::warn!("Skipping invalid MCP server config: {}", error); + None + } + } + }) + .collect(); + + tracing::info!("Loading {} MCP servers from ACP client: {:?}", server_names.len(), server_names); services .mcp_config_manager() @@ -136,6 +161,7 @@ impl StateBuilders { fn acp_to_mcp_server_config(server: &acp::McpServer) -> Result<(ServerName, McpServerConfig)> { match server { acp::McpServer::Stdio(stdio) => { + Self::validate_server_name(&stdio.name)?; let env = stdio .env .iter() @@ -146,35 +172,64 @@ impl StateBuilders { McpServerConfig::new_stdio(stdio.command.to_string_lossy().to_string(), stdio.args.clone(), Some(env)), )) } - acp::McpServer::Http(http) => Ok(( - ServerName::from(http.name.clone()), - McpServerConfig::Http(McpHttpServer { - url: http.url.clone(), - headers: http - .headers - .iter() - .map(|header| (header.name.clone(), header.value.clone())) - .collect(), - timeout: None, - disable: false, - }), - )), - acp::McpServer::Sse(sse) => Ok(( - ServerName::from(sse.name.clone()), - McpServerConfig::Http(McpHttpServer { - url: sse.url.clone(), - headers: sse - .headers - .iter() - .map(|header| (header.name.clone(), header.value.clone())) - .collect(), - timeout: None, - disable: false, - }), - )), + acp::McpServer::Http(http) => { + Self::validate_server_name(&http.name)?; + Ok(( + ServerName::from(http.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: http.url.clone(), + headers: http + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )) + } + acp::McpServer::Sse(sse) => { + Self::validate_server_name(&sse.name)?; + Ok(( + ServerName::from(sse.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: sse.url.clone(), + headers: sse + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )) + } _ => Err(Error::Application(anyhow::anyhow!( "Unsupported MCP server type" ))), } } -} \ No newline at end of file + + /// Validates that an MCP server name is safe to use as a config key. + fn validate_server_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name must not be empty" + ))); + } + if name.len() > MAX_SERVER_NAME_LEN { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name exceeds maximum length of {} characters", + MAX_SERVER_NAME_LEN + ))); + } + // Only allow alphanumeric, hyphens, underscores, and dots. + if !name.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.') { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name '{}' contains invalid characters (allowed: alphanumeric, -, _, .)", + name + ))); + } + Ok(()) + } +} diff --git a/crates/forge_app/src/acp_app.rs b/crates/forge_app/src/acp_app.rs index c86c99bb02..697be3d554 100644 --- a/crates/forge_app/src/acp_app.rs +++ b/crates/forge_app/src/acp_app.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use anyhow::Result; @@ -9,6 +10,12 @@ pub struct AcpApp { services: Arc, } +/// Maximum time to wait for ACP I/O before considering the client hung. +const IO_TIMEOUT: Duration = Duration::from_secs(300); + +/// Maximum time to wait for pending notifications to drain on shutdown. +const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5); + impl AcpApp { /// Creates a new ACP application orchestrator. pub fn new(services: Arc) -> Self { @@ -16,6 +23,13 @@ impl AcpApp { } /// Starts the ACP server over stdio transport. + /// + /// # Trust model + /// + /// The stdio transport inherits OS-level process isolation: only the + /// parent process (e.g. Acepe) that spawned `forge machine stdio` can + /// read/write the stdin/stdout pipes. No network listener is opened. + /// Authentication is therefore a no-op by design. pub async fn start_stdio(&self) -> Result<()> { use agent_client_protocol as acp; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; @@ -25,11 +39,11 @@ impl AcpApp { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() - .expect("Failed to create Tokio runtime"); + .map_err(|e| anyhow::anyhow!("Failed to create Tokio runtime: {}", e))?; rt.block_on(async move { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let adapter = Arc::new(crate::acp::AcpAdapter::new(services, tx)); + let (adapter, mut rx) = crate::acp::AcpAdapter::new(services); + let adapter = Arc::new(adapter); let local_set = tokio::task::LocalSet::new(); local_set @@ -51,7 +65,6 @@ impl AcpApp { let conn_for_notifications = conn.clone(); let notification_task = tokio::task::spawn_local(async move { - let mut rx = rx; while let Some(session_notification) = rx.recv().await { use agent_client_protocol::Client; @@ -68,8 +81,24 @@ impl AcpApp { } }); - let io_result = handle_io.await; - notification_task.abort(); + // Wait for I/O with a timeout to prevent indefinite hangs + // when the client stalls. + let io_result = match tokio::time::timeout(IO_TIMEOUT, handle_io).await { + Ok(result) => result, + Err(_) => { + tracing::warn!("ACP I/O timed out after {:?}", IO_TIMEOUT); + notification_task.abort(); + return Err(anyhow::anyhow!( + "ACP transport timed out after {:?}", + IO_TIMEOUT + )); + } + }; + + // Graceful shutdown: give the notification task time to + // drain pending messages instead of aborting immediately. + drop(adapter); // drops the sender half → rx.recv() returns None + let _ = tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, notification_task).await; io_result.map_err(|error| anyhow::anyhow!("ACP transport error: {}", error)) }) @@ -86,4 +115,4 @@ impl AcpApp { Err(error) => Err(anyhow::anyhow!("ACP server task panicked: {}", error)), } } -} \ No newline at end of file +} diff --git a/crates/forge_main/src/acp_runner.rs b/crates/forge_main/src/acp_runner.rs index e04ea42e31..32671c5411 100644 --- a/crates/forge_main/src/acp_runner.rs +++ b/crates/forge_main/src/acp_runner.rs @@ -3,6 +3,7 @@ use std::future::Future; use anyhow::Result; use forge_api::API; +/// Abstraction over the ACP stdio transport entry point for testability. pub trait MachineStdioApi { fn acp_start_stdio(&self) -> impl Future> + Send; } @@ -13,6 +14,7 @@ impl MachineStdioApi for T { } } +/// Starts the ACP machine stdio server by delegating to the provided API. pub async fn run_machine_stdio_server(api: &A) -> Result<()> { api.acp_start_stdio().await } @@ -54,4 +56,4 @@ mod tests { assert_eq!(actual, expected); Ok(()) } -} \ No newline at end of file +} diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 6d95e12874..3972c151c2 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -213,6 +213,7 @@ pub struct MachineCommandGroup { pub command: MachineCommand, } +/// Machine-oriented subcommands for non-interactive transport protocols. #[derive(Subcommand, Debug, Clone)] pub enum MachineCommand { /// Run the machine interface over stdio. From 4b014b50fd6e673bb4d8ce7f9852d7ceff7de2b4 Mon Sep 17 00:00:00 2001 From: flazouh Date: Mon, 6 Apr 2026 21:46:32 +0300 Subject: [PATCH 3/4] fix(test): move mock assertion inside async block The store ran eagerly on function call, not when the future was awaited. Move it into the async block so the test actually verifies that the caller awaits the returned future. Co-Authored-By: ForgeCode --- crates/forge_main/src/acp_runner.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/forge_main/src/acp_runner.rs b/crates/forge_main/src/acp_runner.rs index 32671c5411..b00cbeab3e 100644 --- a/crates/forge_main/src/acp_runner.rs +++ b/crates/forge_main/src/acp_runner.rs @@ -39,8 +39,11 @@ mod tests { impl MachineStdioApi for MockApi { fn acp_start_stdio(&self) -> impl std::future::Future> + Send { - self.called.store(true, Ordering::SeqCst); - async { Ok(()) } + let called = self.called.clone(); + async move { + called.store(true, Ordering::SeqCst); + Ok(()) + } } } From 80193f9da9eeb7a2dcd4feae22189cf01571e86e Mon Sep 17 00:00:00 2001 From: "Alex (Sasha) Zelts de Pape" Date: Mon, 4 May 2026 03:46:48 +0300 Subject: [PATCH 4/4] test(acp): expand machine stdio coverage Co-Authored-By: ForgeCode --- crates/forge_app/src/acp/adapter.rs | 166 ++- crates/forge_app/src/acp/error.rs | 40 + crates/forge_app/src/acp/prompt_handler.rs | 127 +++ crates/forge_app/src/acp/session_handlers.rs | 1048 +++++++++++++++++- crates/forge_app/src/acp/state_builders.rs | 67 ++ 5 files changed, 1404 insertions(+), 44 deletions(-) diff --git a/crates/forge_app/src/acp/adapter.rs b/crates/forge_app/src/acp/adapter.rs index 9ac30b45bf..cd8a6e5582 100644 --- a/crates/forge_app/src/acp/adapter.rs +++ b/crates/forge_app/src/acp/adapter.rs @@ -5,8 +5,6 @@ use agent_client_protocol as acp; use forge_domain::{AgentId, ConversationId, ModelId}; use tokio::sync::{Mutex, Notify, mpsc}; -use crate::Services; - use super::error::{Error, Result}; /// Maximum number of buffered session notifications before backpressure. @@ -29,11 +27,8 @@ pub(crate) struct AcpAdapter { sessions: Arc>>, } -impl AcpAdapter { - /// Creates a new ACP adapter and returns the notification receiver. - pub(crate) fn new( - services: Arc, - ) -> (Self, mpsc::Receiver) { +impl AcpAdapter { + fn with_services(services: Arc) -> (Self, mpsc::Receiver) { let (tx, rx) = mpsc::channel(NOTIFICATION_CHANNEL_CAPACITY); let adapter = Self { services, @@ -44,6 +39,20 @@ impl AcpAdapter { (adapter, rx) } + #[cfg(test)] + pub(super) fn new_for_test(services: S) -> Self { + Self::with_services(Arc::new(services)).0 + } + + #[cfg(test)] + pub(super) fn new_for_test_with_receiver( + services: S, + ) -> (Self, mpsc::Receiver) { + Self::with_services(Arc::new(services)) + } +} + +impl AcpAdapter { pub(crate) async fn set_client_connection(&self, conn: Arc) { *self.client_conn.lock().await = Some(conn); } @@ -147,3 +156,146 @@ impl AcpAdapter { .map_err(|_| Error::Application(anyhow::anyhow!("Failed to send notification"))) } } + +impl AcpAdapter { + /// Creates a new ACP adapter and returns the notification receiver. + pub(crate) fn new( + services: Arc, + ) -> (Self, mpsc::Receiver) { + Self::with_services(services) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::Duration; + + use forge_domain::{AgentId, ConversationId, ModelId}; + use tokio::sync::Notify; + + use super::{AcpAdapter, SessionState}; + + #[tokio::test] + async fn ensure_session_keeps_existing_state() { + let adapter = AcpAdapter::new_for_test(()); + let conversation_id = ConversationId::generate(); + let notify = Arc::new(Notify::new()); + + adapter + .store_session( + "session-1".to_string(), + SessionState { + conversation_id: conversation_id.clone(), + agent_id: AgentId::new("original-agent"), + model_id: Some(ModelId::new("model-a")), + cancel_notify: Some(notify.clone()), + }, + ) + .await; + + let actual = adapter + .ensure_session( + "session-1", + ConversationId::generate(), + AgentId::new("replacement-agent"), + ) + .await; + + assert_eq!(actual.conversation_id, conversation_id); + assert_eq!(actual.agent_id, AgentId::new("original-agent")); + assert_eq!(actual.model_id, Some(ModelId::new("model-a"))); + assert!(actual.cancel_notify.is_some()); + } + + #[tokio::test] + async fn ensure_session_creates_new_state_when_missing() { + let adapter = AcpAdapter::new_for_test(()); + let conversation_id = ConversationId::generate(); + + let actual = adapter + .ensure_session( + "new-session", + conversation_id.clone(), + AgentId::new("fresh-agent"), + ) + .await; + + assert_eq!(actual.conversation_id, conversation_id); + assert_eq!(actual.agent_id, AgentId::new("fresh-agent")); + assert_eq!(actual.model_id, None); + assert!(actual.cancel_notify.is_none()); + } + + #[tokio::test] + async fn cancel_session_notifies_waiters() { + let adapter = AcpAdapter::new_for_test(()); + let notify = Arc::new(Notify::new()); + let wait_for_cancel_handle = notify.clone(); + let wait_for_cancel = wait_for_cancel_handle.notified(); + + adapter + .store_session( + "session-2".to_string(), + SessionState { + conversation_id: ConversationId::generate(), + agent_id: AgentId::new("agent"), + model_id: None, + cancel_notify: Some(notify), + }, + ) + .await; + + let cancelled = adapter.cancel_session("session-2").await; + + assert!(cancelled); + let result = tokio::time::timeout(Duration::from_millis(100), wait_for_cancel).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn cancel_session_returns_false_when_session_has_no_waiter() { + let adapter = AcpAdapter::new_for_test(()); + + let cancelled = adapter.cancel_session("missing-session").await; + + assert!(!cancelled); + } + + #[tokio::test] + async fn update_methods_change_existing_session() { + let adapter = AcpAdapter::new_for_test(()); + let notify = Arc::new(Notify::new()); + + adapter + .store_session( + "session-3".to_string(), + SessionState { + conversation_id: ConversationId::generate(), + agent_id: AgentId::new("old-agent"), + model_id: None, + cancel_notify: None, + }, + ) + .await; + + adapter + .update_session_agent("session-3", AgentId::new("new-agent")) + .await + .unwrap(); + adapter + .update_session_model("session-3", ModelId::new("new-model")) + .await + .unwrap(); + adapter + .set_cancel_notify("session-3", Some(notify.clone())) + .await + .unwrap(); + + let actual = adapter.session_state("session-3").await.unwrap(); + + assert_eq!(actual.agent_id, AgentId::new("new-agent")); + assert_eq!(actual.model_id, Some(ModelId::new("new-model"))); + assert!(actual.cancel_notify.is_some()); + } +} diff --git a/crates/forge_app/src/acp/error.rs b/crates/forge_app/src/acp/error.rs index 9a452b4a93..1b5aad0d0c 100644 --- a/crates/forge_app/src/acp/error.rs +++ b/crates/forge_app/src/acp/error.rs @@ -27,3 +27,43 @@ pub fn into_acp_error(error: Error) -> acp::Error { Error::Io(error) => acp::Error::into_internal_error(&error), } } + +#[cfg(test)] +mod tests { + use std::io; + + use agent_client_protocol as acp; + + use super::{Error, into_acp_error}; + + #[test] + fn preserves_protocol_errors() { + let error = acp::Error::invalid_params(); + + let actual = into_acp_error(Error::Protocol(error.clone())); + + assert_eq!(actual.code, error.code); + assert_eq!(actual.message, error.message); + } + + #[test] + fn wraps_application_errors_as_internal_errors() { + let actual = into_acp_error(Error::Application(anyhow::anyhow!("boom"))); + + assert_eq!(actual.code, acp::ErrorCode::InternalError); + assert_eq!(actual.message, "Internal error"); + assert_eq!(actual.data, Some(serde_json::Value::String("boom".to_string()))); + } + + #[test] + fn wraps_io_errors_as_internal_errors() { + let actual = into_acp_error(Error::Io(io::Error::other("disk failure"))); + + assert_eq!(actual.code, acp::ErrorCode::InternalError); + assert_eq!(actual.message, "Internal error"); + assert_eq!( + actual.data, + Some(serde_json::Value::String("disk failure".to_string())) + ); + } +} diff --git a/crates/forge_app/src/acp/prompt_handler.rs b/crates/forge_app/src/acp/prompt_handler.rs index 1f37d6964a..3cfb78fd0a 100644 --- a/crates/forge_app/src/acp/prompt_handler.rs +++ b/crates/forge_app/src/acp/prompt_handler.rs @@ -137,6 +137,9 @@ impl> AcpAdapter { } } +} + +impl AcpAdapter { async fn handle_chat_response( &self, session_id: &acp::SessionId, @@ -302,3 +305,127 @@ fn format_interruption(reason: &InterruptionReason) -> (String, String) { ), } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use agent_client_protocol as acp; + use forge_domain::{ChatResponse, ChatResponseContent, InterruptionReason, ToolName}; + use tokio::sync::mpsc::error::TryRecvError; + + use super::{AcpAdapter, format_interruption}; + + #[test] + fn formats_tool_failure_interruptions() { + let mut errors = HashMap::new(); + errors.insert(ToolName::new("read"), 2); + errors.insert(ToolName::new("write"), 1); + + let actual = format_interruption( + &InterruptionReason::MaxToolFailurePerTurnLimitReached { limit: 3, errors }, + ); + + assert_eq!(actual.0, "Tool failure limit reached (3)"); + assert!(actual.1.contains("read (2)")); + assert!(actual.1.contains("write (1)")); + } + + #[test] + fn formats_request_limit_interruptions() { + let actual = format_interruption( + &InterruptionReason::MaxRequestPerTurnLimitReached { limit: 5 }, + ); + + assert_eq!(actual.0, "Request limit reached (5)"); + assert_eq!( + actual.1, + "Forge reached the maximum number of requests for this turn." + ); + } + + #[tokio::test] + async fn task_message_sends_agent_message_notification() { + let (adapter, mut rx) = AcpAdapter::new_for_test_with_receiver(()); + let session_id = acp::SessionId::new("session-1"); + + adapter + .handle_task_message( + &session_id, + ChatResponseContent::Markdown { + text: "Hello from Forge".to_string(), + partial: false, + }, + ) + .await + .unwrap(); + + let actual = rx.try_recv().unwrap(); + assert_eq!(actual.session_id, session_id); + assert!(matches!( + actual.update, + acp::SessionUpdate::AgentMessageChunk(_) + )); + } + + #[tokio::test] + async fn empty_task_message_is_ignored() { + let (adapter, mut rx) = AcpAdapter::new_for_test_with_receiver(()); + + adapter + .handle_task_message( + &acp::SessionId::new("session-2"), + ChatResponseContent::Markdown { + text: String::new(), + partial: false, + }, + ) + .await + .unwrap(); + + assert_eq!(rx.try_recv(), Err(TryRecvError::Empty)); + } + + #[tokio::test] + async fn task_reasoning_sends_thought_notification() { + let (adapter, mut rx) = AcpAdapter::new_for_test_with_receiver(()); + let session_id = acp::SessionId::new("session-3"); + + adapter + .handle_chat_response( + &session_id, + ChatResponse::TaskReasoning { + content: "Thinking".to_string(), + }, + &mut false, + ) + .await + .unwrap(); + + let actual = rx.try_recv().unwrap(); + assert_eq!(actual.session_id, session_id); + assert!(matches!( + actual.update, + acp::SessionUpdate::AgentThoughtChunk(_) + )); + } + + #[tokio::test] + async fn interrupt_without_client_does_not_enable_continue() { + let adapter = AcpAdapter::new_for_test(()); + let mut continue_after_interrupt = false; + + adapter + .handle_chat_response( + &acp::SessionId::new("session-4"), + ChatResponse::Interrupt { + reason: InterruptionReason::MaxRequestPerTurnLimitReached { limit: 2 }, + }, + &mut continue_after_interrupt, + ) + .await + .unwrap(); + + assert!(!continue_after_interrupt); + } +} diff --git a/crates/forge_app/src/acp/session_handlers.rs b/crates/forge_app/src/acp/session_handlers.rs index 50e2696a3e..6fc2e0bfaa 100644 --- a/crates/forge_app/src/acp/session_handlers.rs +++ b/crates/forge_app/src/acp/session_handlers.rs @@ -10,7 +10,7 @@ use super::state_builders::StateBuilders; const VERSION: &str = env!("CARGO_PKG_VERSION"); -impl> AcpAdapter { +impl AcpAdapter { pub(super) async fn handle_initialize( &self, arguments: acp::InitializeRequest, @@ -45,7 +45,9 @@ impl> AcpAdapter { tracing::debug!("ACP authenticate: no-op (stdio transport uses OS process isolation)"); Ok(acp::AuthenticateResponse::default()) } +} +impl> AcpAdapter { pub(super) async fn handle_new_session( &self, arguments: acp::NewSessionRequest, @@ -171,42 +173,6 @@ impl> AcpAdapter { .models(model_state)) } - pub(super) async fn handle_cancel( - &self, - arguments: acp::CancelNotification, - ) -> std::result::Result<(), acp::Error> { - let session_key = arguments.session_id.0.as_ref().to_string(); - let cancelled = self.cancel_session(&session_key).await; - if !cancelled { - tracing::warn!("No active ACP prompt to cancel for session {}", session_key); - } - Ok(()) - } - - pub(super) async fn handle_set_session_mode( - &self, - arguments: acp::SetSessionModeRequest, - ) -> std::result::Result { - let session_key = arguments.session_id.0.as_ref().to_string(); - let mode_id = arguments.mode_id.0.as_ref(); - let agent_id = AgentId::new(mode_id); - - self.update_session_agent(&session_key, agent_id.clone()) - .await - .map_err(error::into_acp_error)?; - - let notification = acp::SessionNotification::new( - arguments.session_id, - acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate::new( - acp::SessionModeId::new(mode_id.to_string()), - )), - ); - self.send_notification(notification) - .map_err(error::into_acp_error)?; - - Ok(acp::SetSessionModeResponse::new()) - } - /// Handles session model changes. /// /// The model preference is stored per-session so that concurrent ACP @@ -268,3 +234,1011 @@ impl> AcpAdapter { Ok(acp::SetSessionModelResponse::default()) } } + +impl AcpAdapter { + pub(super) async fn handle_cancel( + &self, + arguments: acp::CancelNotification, + ) -> std::result::Result<(), acp::Error> { + let session_key = arguments.session_id.0.as_ref().to_string(); + let cancelled = self.cancel_session(&session_key).await; + if !cancelled { + tracing::warn!("No active ACP prompt to cancel for session {}", session_key); + } + Ok(()) + } + + pub(super) async fn handle_set_session_mode( + &self, + arguments: acp::SetSessionModeRequest, + ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); + let mode_id = arguments.mode_id.0.as_ref(); + let agent_id = AgentId::new(mode_id); + + self.update_session_agent(&session_key, agent_id.clone()) + .await + .map_err(error::into_acp_error)?; + + let notification = acp::SessionNotification::new( + arguments.session_id, + acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate::new( + acp::SessionModeId::new(mode_id.to_string()), + )), + ); + self.send_notification(notification) + .map_err(error::into_acp_error)?; + + Ok(acp::SetSessionModeResponse::new()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::{BTreeMap, HashMap}; + use std::path::{Path, PathBuf}; + use std::sync::{Arc, Mutex}; + + use agent_client_protocol as acp; + use forge_config::ForgeConfig; + use forge_domain::{ + Agent, AgentId, AnyProvider, Attachment, AuthContextRequest, AuthContextResponse, + AuthMethod, ChatCompletionMessage, CommandOutput, ConfigOperation, Context, Conversation, + ConversationId, File, FileStatus, Image, McpConfig, McpServers, Model, ModelConfig, + ModelId, Node, Provider, ProviderId, ProviderResponse, ProviderType, Scope, SearchParams, + Skill, SyncProgress, Template, ToolCallFull, ToolOutput, URLParamSpec, WorkspaceAuth, + WorkspaceId, WorkspaceInfo, + }; + use reqwest::Url; + + use super::{AcpAdapter, SessionState}; + use crate::infra::EnvironmentInfra; + use crate::services::{ + AgentRegistry, AppConfigService, AttachmentService, AuthService, CommandLoaderService, + ConversationService, CustomInstructionsService, FileDiscoveryService, FollowUpService, + FsPatchService, FsReadService, FsRemoveService, FsSearchService, FsUndoService, + FsWriteService, HttpResponse, ImageReadService, McpConfigManager, McpService, + NetFetchService, PatchOutput, PlanCreateOutput, PlanCreateService, PolicyDecision, + PolicyService, ProviderAuthService, ProviderService, ReadOutput, ResponseContext, + SearchResult, Services, ShellOutput, ShellService, SkillFetchService, TemplateService, + WorkspaceService, + }; + use crate::user::{AuthProviderId, Plan, UsageInfo, User, UserUsage}; + use crate::Walker; + + #[derive(Clone)] + struct SharedState(Arc>); + + #[derive(Clone)] + struct MockServices { + provider_service: MockProviderService, + config_service: MockConfigService, + conversation_service: MockConversationService, + mcp_config_manager: MockMcpConfigService, + agent_registry: MockAgentRegistryService, + noop_service: NoopService, + environment: forge_domain::Environment, + config: ForgeConfig, + } + + struct MockState { + active_agent_id: Option, + agents: Vec, + conversations: HashMap, + provider: Provider, + models: Vec, + mcp_config: McpConfig, + config_updates: Vec>, + } + + #[derive(Clone)] + struct MockProviderService { + state: SharedState, + } + + #[derive(Clone)] + struct MockConfigService { + state: SharedState, + } + + #[derive(Clone)] + struct MockConversationService { + state: SharedState, + } + + #[derive(Clone)] + struct MockMcpConfigService { + state: SharedState, + } + + #[derive(Clone)] + struct MockAgentRegistryService { + state: SharedState, + } + + #[derive(Clone, Default)] + struct NoopService; + + impl MockServices { + fn new() -> Self { + let agent = Agent::new( + AgentId::new("forge"), + ProviderId::OPENAI, + ModelId::new("test-model"), + ) + .title("Forge") + .description("Test agent"); + let provider = Provider { + id: ProviderId::OPENAI, + provider_type: ProviderType::Llm, + response: Some(ProviderResponse::OpenAI), + url: Url::parse("https://api.example.com/chat").unwrap(), + models: None, + auth_methods: vec![AuthMethod::ApiKey], + url_params: Vec::::new(), + credential: None, + custom_headers: None, + }; + let model = Model { + id: ModelId::new("test-model"), + name: Some("Test Model".to_string()), + description: Some("Model used by ACP tests".to_string()), + context_length: Some(8192), + tools_supported: Some(true), + supports_parallel_tool_calls: Some(true), + supports_reasoning: Some(false), + input_modalities: vec![forge_domain::InputModality::Text], + }; + let state = SharedState(Arc::new(Mutex::new(MockState { + active_agent_id: Some(agent.id.clone()), + agents: vec![agent], + conversations: HashMap::new(), + provider: provider.clone(), + models: vec![model], + mcp_config: McpConfig::default(), + config_updates: Vec::new(), + }))); + + Self { + provider_service: MockProviderService { + state: state.clone(), + }, + config_service: MockConfigService { + state: state.clone(), + }, + conversation_service: MockConversationService { + state: state.clone(), + }, + mcp_config_manager: MockMcpConfigService { + state: state.clone(), + }, + agent_registry: MockAgentRegistryService { state }, + noop_service: NoopService, + environment: forge_domain::Environment { + os: "macos".to_string(), + cwd: PathBuf::from("/tmp/project"), + home: Some(PathBuf::from("/tmp/home")), + shell: "/bin/zsh".to_string(), + base_path: PathBuf::from("/tmp/forge"), + }, + config: ForgeConfig::default(), + } + } + + fn insert_conversation(&self, conversation: Conversation) { + self.conversation_service + .state + .0 + .lock() + .unwrap() + .conversations + .insert(conversation.id, conversation); + } + + fn config_updates(&self) -> Vec> { + self.config_service + .state + .0 + .lock() + .unwrap() + .config_updates + .clone() + } + } + + impl EnvironmentInfra for MockServices { + type Config = ForgeConfig; + + fn get_env_var(&self, _key: &str) -> Option { + None + } + + fn get_env_vars(&self) -> BTreeMap { + BTreeMap::new() + } + + fn get_environment(&self) -> forge_domain::Environment { + self.environment.clone() + } + + fn get_config(&self) -> anyhow::Result { + Ok(self.config.clone()) + } + + fn update_environment( + &self, + _ops: Vec, + ) -> impl std::future::Future> + Send { + async { Ok(()) } + } + } + + #[async_trait::async_trait] + impl ProviderService for MockProviderService { + async fn chat( + &self, + _model_id: &ModelId, + _context: Context, + _provider: Provider, + ) -> forge_domain::ResultStream { + todo!("unused in session handler tests") + } + + async fn models(&self, _provider: Provider) -> anyhow::Result> { + Ok(self.state.0.lock().unwrap().models.clone()) + } + + async fn get_provider(&self, _id: ProviderId) -> anyhow::Result> { + Ok(self.state.0.lock().unwrap().provider.clone()) + } + + async fn get_all_providers(&self) -> anyhow::Result> { + Ok(vec![AnyProvider::Url(self.state.0.lock().unwrap().provider.clone())]) + } + + async fn upsert_credential( + &self, + _credential: forge_domain::AuthCredential, + ) -> anyhow::Result<()> { + todo!("unused in session handler tests") + } + + async fn remove_credential(&self, _id: &ProviderId) -> anyhow::Result<()> { + todo!("unused in session handler tests") + } + + async fn migrate_env_credentials( + &self, + ) -> anyhow::Result> { + Ok(None) + } + } + + #[async_trait::async_trait] + impl AppConfigService for MockConfigService { + async fn get_session_config(&self) -> Option { + Some(ModelConfig::new(ProviderId::OPENAI, ModelId::new("test-model"))) + } + + async fn get_commit_config(&self) -> anyhow::Result> { + Ok(None) + } + + async fn get_suggest_config(&self) -> anyhow::Result> { + Ok(None) + } + + async fn get_reasoning_effort(&self) -> anyhow::Result> { + Ok(None) + } + + async fn update_config(&self, ops: Vec) -> anyhow::Result<()> { + self.state.0.lock().unwrap().config_updates.push(ops); + Ok(()) + } + } + + #[async_trait::async_trait] + impl ConversationService for MockConversationService { + async fn find_conversation(&self, id: &ConversationId) -> anyhow::Result> { + Ok(self.state.0.lock().unwrap().conversations.get(id).cloned()) + } + + async fn upsert_conversation(&self, conversation: Conversation) -> anyhow::Result<()> { + self.state + .0 + .lock() + .unwrap() + .conversations + .insert(conversation.id, conversation); + Ok(()) + } + + async fn modify_conversation(&self, id: &ConversationId, f: F) -> anyhow::Result + where + F: FnOnce(&mut Conversation) -> T + Send, + T: Send, + { + let mut guard = self.state.0.lock().unwrap(); + let conversation = guard.conversations.get_mut(id).expect("conversation must exist"); + Ok(f(conversation)) + } + + async fn get_conversations( + &self, + _limit: Option, + ) -> anyhow::Result>> { + Ok(Some( + self.state + .0 + .lock() + .unwrap() + .conversations + .values() + .cloned() + .collect(), + )) + } + + async fn last_conversation(&self) -> anyhow::Result> { + Ok(self + .state + .0 + .lock() + .unwrap() + .conversations + .values() + .last() + .cloned()) + } + + async fn delete_conversation(&self, conversation_id: &ConversationId) -> anyhow::Result<()> { + self.state.0.lock().unwrap().conversations.remove(conversation_id); + Ok(()) + } + } + + #[async_trait::async_trait] + impl AgentRegistry for MockAgentRegistryService { + async fn get_active_agent_id(&self) -> anyhow::Result> { + Ok(self.state.0.lock().unwrap().active_agent_id.clone()) + } + + async fn set_active_agent_id(&self, agent_id: AgentId) -> anyhow::Result<()> { + self.state.0.lock().unwrap().active_agent_id = Some(agent_id); + Ok(()) + } + + async fn get_agents(&self) -> anyhow::Result> { + Ok(self.state.0.lock().unwrap().agents.clone()) + } + + async fn get_agent_infos(&self) -> anyhow::Result> { + Ok(self + .state + .0 + .lock() + .unwrap() + .agents + .iter() + .map(|agent| { + let mut info = forge_domain::AgentInfo::default().id(agent.id.clone()); + if let Some(title) = agent.title.clone() { + info = info.title(title); + } + if let Some(description) = agent.description.clone() { + info = info.description(description); + } + info + }) + .collect()) + } + + async fn get_agent(&self, agent_id: &AgentId) -> anyhow::Result> { + Ok(self + .state + .0 + .lock() + .unwrap() + .agents + .iter() + .find(|agent| &agent.id == agent_id) + .cloned()) + } + + async fn reload_agents(&self) -> anyhow::Result<()> { + Ok(()) + } + } + + #[async_trait::async_trait] + impl McpConfigManager for MockMcpConfigService { + async fn read_mcp_config(&self, _scope: Option<&Scope>) -> anyhow::Result { + Ok(self.state.0.lock().unwrap().mcp_config.clone()) + } + + async fn write_mcp_config(&self, config: &McpConfig, _scope: &Scope) -> anyhow::Result<()> { + self.state.0.lock().unwrap().mcp_config = config.clone(); + Ok(()) + } + } + + #[async_trait::async_trait] + impl McpService for NoopService { + async fn get_mcp_servers(&self) -> anyhow::Result { + Ok(McpServers::default()) + } + + async fn execute_mcp(&self, _call: ToolCallFull) -> anyhow::Result { + todo!("unused in session handler tests") + } + + async fn reload_mcp(&self) -> anyhow::Result<()> { + Ok(()) + } + } + + #[async_trait::async_trait] + impl ProviderAuthService for NoopService { + async fn init_provider_auth( + &self, + _provider_id: ProviderId, + _method: AuthMethod, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + + async fn complete_provider_auth( + &self, + _provider_id: ProviderId, + _context: AuthContextResponse, + _timeout: std::time::Duration, + ) -> anyhow::Result<()> { + todo!("unused in session handler tests") + } + + async fn refresh_provider_credential( + &self, + provider: Provider, + ) -> anyhow::Result> { + Ok(provider) + } + } + + #[async_trait::async_trait] + impl TemplateService for NoopService { + async fn register_template(&self, _path: PathBuf) -> anyhow::Result<()> { + todo!("unused in session handler tests") + } + + async fn render_template( + &self, + _template: Template, + _object: &V, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl AttachmentService for NoopService { + async fn attachments(&self, _url: &str) -> anyhow::Result> { + Ok(vec![]) + } + } + + #[async_trait::async_trait] + impl CustomInstructionsService for NoopService { + async fn get_custom_instructions(&self) -> Vec { + vec![] + } + } + + #[async_trait::async_trait] + impl FileDiscoveryService for NoopService { + async fn collect_files(&self, _config: Walker) -> anyhow::Result> { + Ok(vec![]) + } + + async fn list_current_directory(&self) -> anyhow::Result> { + Ok(vec![]) + } + } + + #[async_trait::async_trait] + impl FsWriteService for NoopService { + async fn write( + &self, + _path: String, + _content: String, + _overwrite: bool, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl PlanCreateService for NoopService { + async fn create_plan( + &self, + _plan_name: String, + _version: String, + _content: String, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl FsPatchService for NoopService { + async fn patch( + &self, + _path: String, + _search: String, + _content: String, + _replace_all: bool, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + + async fn multi_patch( + &self, + _path: String, + _edits: Vec, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl FsReadService for NoopService { + async fn read( + &self, + _path: String, + _start_line: Option, + _end_line: Option, + ) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl ImageReadService for NoopService { + async fn read_image(&self, _path: String) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl FsRemoveService for NoopService { + async fn remove(&self, _path: String) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl FsSearchService for NoopService { + async fn search(&self, _params: forge_domain::FSSearch) -> anyhow::Result> { + Ok(None) + } + } + + #[async_trait::async_trait] + impl FollowUpService for NoopService { + async fn follow_up( + &self, + _question: String, + _options: Vec, + _multiple: Option, + ) -> anyhow::Result> { + Ok(None) + } + } + + #[async_trait::async_trait] + impl FsUndoService for NoopService { + async fn undo(&self, _path: String) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl NetFetchService for NoopService { + async fn fetch(&self, _url: String, _raw: Option) -> anyhow::Result { + Ok(HttpResponse { + content: String::new(), + code: 200, + context: ResponseContext::Raw, + content_type: "text/plain".to_string(), + }) + } + } + + #[async_trait::async_trait] + impl ShellService for NoopService { + async fn execute( + &self, + _command: String, + _cwd: PathBuf, + _keep_ansi: bool, + _silent: bool, + _env_vars: Option>, + _description: Option, + ) -> anyhow::Result { + Ok(ShellOutput { + output: CommandOutput { + command: String::new(), + stdout: String::new(), + stderr: String::new(), + exit_code: Some(0), + }, + shell: "/bin/zsh".to_string(), + description: None, + }) + } + } + + #[async_trait::async_trait] + impl AuthService for NoopService { + async fn user_info(&self, _api_key: &str) -> anyhow::Result { + Ok(User { + auth_provider_id: AuthProviderId::new("test"), + }) + } + + async fn user_usage(&self, _api_key: &str) -> anyhow::Result { + Ok(UserUsage { + plan: Plan { r#type: "free".to_string() }, + usage: UsageInfo { + current: 0, + limit: 0, + remaining: 0, + reset_in: None, + }, + }) + } + } + + #[async_trait::async_trait] + impl CommandLoaderService for NoopService { + async fn get_commands(&self) -> anyhow::Result> { + Ok(vec![]) + } + } + + #[async_trait::async_trait] + impl PolicyService for NoopService { + async fn check_operation_permission( + &self, + _operation: &forge_domain::PermissionOperation, + ) -> anyhow::Result { + Ok(PolicyDecision { + allowed: true, + path: None, + }) + } + } + + #[async_trait::async_trait] + impl WorkspaceService for NoopService { + async fn sync_workspace( + &self, + _path: PathBuf, + ) -> anyhow::Result>> { + todo!("unused in session handler tests") + } + + async fn query_workspace( + &self, + _path: PathBuf, + _params: SearchParams<'_>, + ) -> anyhow::Result> { + todo!("unused in session handler tests") + } + + async fn list_workspaces(&self) -> anyhow::Result> { + Ok(vec![]) + } + + async fn get_workspace_info(&self, _path: PathBuf) -> anyhow::Result> { + Ok(None) + } + + async fn delete_workspace(&self, _workspace_id: &WorkspaceId) -> anyhow::Result<()> { + Ok(()) + } + + async fn delete_workspaces(&self, _workspace_ids: &[WorkspaceId]) -> anyhow::Result<()> { + Ok(()) + } + + async fn is_indexed(&self, _path: &Path) -> anyhow::Result { + Ok(false) + } + + async fn get_workspace_status(&self, _path: PathBuf) -> anyhow::Result> { + Ok(vec![]) + } + + async fn is_authenticated(&self) -> anyhow::Result { + Ok(false) + } + + async fn init_auth_credentials(&self) -> anyhow::Result { + todo!("unused in session handler tests") + } + + async fn init_workspace(&self, _path: PathBuf) -> anyhow::Result { + todo!("unused in session handler tests") + } + } + + #[async_trait::async_trait] + impl SkillFetchService for NoopService { + async fn fetch_skill(&self, _skill_name: String) -> anyhow::Result { + todo!("unused in session handler tests") + } + + async fn list_skills(&self) -> anyhow::Result> { + Ok(vec![]) + } + } + + impl Services for MockServices { + type ProviderService = MockProviderService; + type AppConfigService = MockConfigService; + type ConversationService = MockConversationService; + type TemplateService = NoopService; + type AttachmentService = NoopService; + type CustomInstructionsService = NoopService; + type FileDiscoveryService = NoopService; + type McpConfigManager = MockMcpConfigService; + type FsWriteService = NoopService; + type PlanCreateService = NoopService; + type FsPatchService = NoopService; + type FsReadService = NoopService; + type ImageReadService = NoopService; + type FsRemoveService = NoopService; + type FsSearchService = NoopService; + type FollowUpService = NoopService; + type FsUndoService = NoopService; + type NetFetchService = NoopService; + type ShellService = NoopService; + type McpService = NoopService; + type AuthService = NoopService; + type AgentRegistry = MockAgentRegistryService; + type CommandLoaderService = NoopService; + type PolicyService = NoopService; + type ProviderAuthService = NoopService; + type WorkspaceService = NoopService; + type SkillFetchService = NoopService; + + fn provider_service(&self) -> &Self::ProviderService { &self.provider_service } + fn config_service(&self) -> &Self::AppConfigService { &self.config_service } + fn conversation_service(&self) -> &Self::ConversationService { &self.conversation_service } + fn template_service(&self) -> &Self::TemplateService { &self.noop_service } + fn attachment_service(&self) -> &Self::AttachmentService { &self.noop_service } + fn file_discovery_service(&self) -> &Self::FileDiscoveryService { &self.noop_service } + fn mcp_config_manager(&self) -> &Self::McpConfigManager { &self.mcp_config_manager } + fn fs_create_service(&self) -> &Self::FsWriteService { &self.noop_service } + fn plan_create_service(&self) -> &Self::PlanCreateService { &self.noop_service } + fn fs_patch_service(&self) -> &Self::FsPatchService { &self.noop_service } + fn fs_read_service(&self) -> &Self::FsReadService { &self.noop_service } + fn image_read_service(&self) -> &Self::ImageReadService { &self.noop_service } + fn fs_remove_service(&self) -> &Self::FsRemoveService { &self.noop_service } + fn fs_search_service(&self) -> &Self::FsSearchService { &self.noop_service } + fn follow_up_service(&self) -> &Self::FollowUpService { &self.noop_service } + fn fs_undo_service(&self) -> &Self::FsUndoService { &self.noop_service } + fn net_fetch_service(&self) -> &Self::NetFetchService { &self.noop_service } + fn shell_service(&self) -> &Self::ShellService { &self.noop_service } + fn mcp_service(&self) -> &Self::McpService { &self.noop_service } + fn custom_instructions_service(&self) -> &Self::CustomInstructionsService { &self.noop_service } + fn auth_service(&self) -> &Self::AuthService { &self.noop_service } + fn agent_registry(&self) -> &Self::AgentRegistry { &self.agent_registry } + fn command_loader_service(&self) -> &Self::CommandLoaderService { &self.noop_service } + fn policy_service(&self) -> &Self::PolicyService { &self.noop_service } + fn provider_auth_service(&self) -> &Self::ProviderAuthService { &self.noop_service } + fn workspace_service(&self) -> &Self::WorkspaceService { &self.noop_service } + fn skill_fetch_service(&self) -> &Self::SkillFetchService { &self.noop_service } + } + + #[tokio::test] + async fn initialize_exposes_acp_capabilities() { + let adapter = AcpAdapter::new_for_test(()); + + let actual = adapter + .handle_initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1)) + .await + .unwrap(); + + assert_eq!(actual.protocol_version, acp::ProtocolVersion::V1); + assert!(actual.agent_capabilities.load_session); + assert!(actual.agent_capabilities.mcp_capabilities.http); + assert!(actual.agent_capabilities.mcp_capabilities.sse); + } + + #[tokio::test] + async fn authenticate_is_a_no_op() { + let adapter = AcpAdapter::new_for_test(()); + + let actual = adapter + .handle_authenticate(acp::AuthenticateRequest::new(acp::AuthMethodId::new("stdio"))) + .await + .unwrap(); + + assert_eq!(actual, acp::AuthenticateResponse::default()); + } + + #[tokio::test] + async fn cancel_returns_ok_when_session_is_missing() { + let adapter = AcpAdapter::new_for_test(()); + + let actual = adapter + .handle_cancel(acp::CancelNotification::new(acp::SessionId::new("missing"))) + .await; + + assert!(actual.is_ok()); + } + + #[tokio::test] + async fn set_session_mode_updates_state_and_emits_notification() { + let (adapter, mut rx) = AcpAdapter::new_for_test_with_receiver(()); + let session_id = acp::SessionId::new("session-4"); + let conversation_id = ConversationId::generate(); + + adapter + .store_session( + "session-4".to_string(), + SessionState { + conversation_id, + agent_id: AgentId::new("before"), + model_id: None, + cancel_notify: None, + }, + ) + .await; + + let actual = adapter + .handle_set_session_mode(acp::SetSessionModeRequest::new( + session_id.clone(), + acp::SessionModeId::new("after"), + )) + .await; + + assert!(actual.is_ok()); + let state = adapter.session_state("session-4").await.unwrap(); + assert_eq!(state.agent_id, AgentId::new("after")); + + let notification = rx.recv().await; + assert!(notification.is_some()); + let notification = notification.unwrap(); + assert_eq!(notification.session_id, session_id); + } + + #[tokio::test] + async fn new_session_creates_conversation_and_returns_initial_state() { + let services = MockServices::new(); + let adapter = AcpAdapter::new_for_test(services.clone()); + + let actual = adapter + .handle_new_session(acp::NewSessionRequest::new(PathBuf::from("/tmp/project"))) + .await + .unwrap(); + + let conversation_id = ConversationId::parse(actual.session_id.0.as_ref()).unwrap(); + let stored = services.find_conversation(&conversation_id).await.unwrap(); + + assert!(stored.is_some()); + assert_eq!( + actual + .modes + .as_ref() + .map(|modes| modes.current_mode_id.0.as_ref()), + Some("forge") + ); + assert_eq!( + actual + .models + .as_ref() + .map(|models| models.current_model_id.0.as_ref()), + Some("test-model") + ); + let session = adapter.session_state(actual.session_id.0.as_ref()).await.unwrap(); + assert_eq!(session.agent_id, AgentId::new("forge")); + } + + #[tokio::test] + async fn load_session_returns_invalid_params_for_unknown_conversation() { + let services = MockServices::new(); + let adapter = AcpAdapter::new_for_test(services); + + let actual = adapter + .handle_load_session(acp::LoadSessionRequest::new( + acp::SessionId::new(ConversationId::generate().into_string()), + PathBuf::from("/tmp/project"), + )) + .await; + + assert!(actual.is_err()); + assert_eq!(actual.unwrap_err().code, acp::ErrorCode::InvalidParams); + } + + #[tokio::test] + async fn load_session_uses_existing_conversation_and_builds_state() { + let services = MockServices::new(); + let adapter = AcpAdapter::new_for_test(services.clone()); + let conversation = Conversation::generate(); + let conversation_id = conversation.id; + services.insert_conversation(conversation); + + let actual = adapter + .handle_load_session(acp::LoadSessionRequest::new( + acp::SessionId::new(conversation_id.into_string()), + PathBuf::from("/tmp/project"), + )) + .await + .unwrap(); + + assert_eq!( + actual + .modes + .as_ref() + .map(|modes| modes.current_mode_id.0.as_ref()), + Some("forge") + ); + assert_eq!( + actual + .models + .as_ref() + .map(|models| models.current_model_id.0.as_ref()), + Some("test-model") + ); + let session = adapter + .session_state(conversation_id.into_string().as_str()) + .await + .unwrap(); + assert_eq!(session.conversation_id, conversation_id); + } + + #[tokio::test] + async fn set_session_model_updates_session_and_config() { + let (adapter, mut rx) = AcpAdapter::new_for_test_with_receiver(MockServices::new()); + let conversation = Conversation::generate(); + let session_id = acp::SessionId::new(conversation.id.into_string()); + + adapter + .store_session( + session_id.0.as_ref().to_string(), + SessionState { + conversation_id: conversation.id, + agent_id: AgentId::new("forge"), + model_id: None, + cancel_notify: None, + }, + ) + .await; + + let actual = adapter + .handle_set_session_model(acp::SetSessionModelRequest::new( + session_id.clone(), + acp::ModelId::new("gpt-test"), + )) + .await; + + assert!(actual.is_ok()); + let session = adapter.session_state(session_id.0.as_ref()).await.unwrap(); + assert_eq!(session.model_id, Some(ModelId::new("gpt-test"))); + + let updates = adapter.services.config_updates(); + assert_eq!( + updates, + vec![vec![ConfigOperation::SetSessionConfig(ModelConfig::new( + ProviderId::OPENAI, + ModelId::new("gpt-test"), + ))]] + ); + + let notification = rx.recv().await.expect("expected model change notification"); + assert_eq!(notification.session_id, session_id); + } +} diff --git a/crates/forge_app/src/acp/state_builders.rs b/crates/forge_app/src/acp/state_builders.rs index abf32b5d6c..a0fe49cd44 100644 --- a/crates/forge_app/src/acp/state_builders.rs +++ b/crates/forge_app/src/acp/state_builders.rs @@ -237,3 +237,70 @@ impl StateBuilders { Ok(()) } } + +#[cfg(test)] +mod tests { + use agent_client_protocol as acp; + use agent_client_protocol::{EnvVariable, HttpHeader}; + use forge_domain::{McpOAuthSetting, McpServerConfig}; + + use super::StateBuilders; + + #[test] + fn maps_stdio_servers_with_env() { + let server = acp::McpServer::Stdio( + acp::McpServerStdio::new("local-server", "/bin/echo") + .args(vec!["hello".to_string()]) + .env(vec![EnvVariable::new("TOKEN", "secret")]), + ); + + let (name, config) = StateBuilders::acp_to_mcp_server_config(&server).unwrap(); + + assert_eq!(name.to_string(), "local-server"); + match config { + McpServerConfig::Stdio(stdio) => { + assert_eq!(stdio.command, "/bin/echo"); + assert_eq!(stdio.args, vec!["hello".to_string()]); + assert_eq!(stdio.env.get("TOKEN"), Some(&"secret".to_string())); + } + McpServerConfig::Http(_) => panic!("expected stdio config"), + } + } + + #[test] + fn maps_http_servers_with_auto_detect_oauth() { + let server = acp::McpServer::Http( + acp::McpServerHttp::new("remote.server", "https://example.com/mcp").headers(vec![ + HttpHeader::new("Authorization", "Bearer token"), + ]), + ); + + let (name, config) = StateBuilders::acp_to_mcp_server_config(&server).unwrap(); + + assert_eq!(name.to_string(), "remote.server"); + match config { + McpServerConfig::Http(http) => { + assert_eq!(http.url, "https://example.com/mcp"); + assert_eq!( + http.headers.get("Authorization"), + Some(&"Bearer token".to_string()) + ); + assert_eq!(http.oauth, McpOAuthSetting::AutoDetect); + } + McpServerConfig::Stdio(_) => panic!("expected http config"), + } + } + + #[test] + fn rejects_invalid_server_names() { + let server = acp::McpServer::Sse(acp::McpServerSse::new( + "bad server name!", + "https://example.com/sse", + )); + + let error = StateBuilders::acp_to_mcp_server_config(&server).unwrap_err(); + let actual = error.to_string(); + + assert!(actual.contains("invalid characters")); + } +}