diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index b9b98c10d..94a47cd11 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -1,15 +1,17 @@ use chrono::{DateTime, TimeZone, Utc}; +use mcp_client::McpService; use rust_decimal_macros::dec; use std::collections::HashMap; use std::sync::Arc; use std::sync::LazyLock; +use std::time::Duration; use tokio::sync::Mutex; use tracing::{debug, instrument}; use super::system::{SystemConfig, SystemError, SystemInfo, SystemResult}; use crate::prompt_template::load_prompt_file; use crate::providers::base::{Provider, ProviderUsage}; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult}; @@ -20,7 +22,7 @@ static DEFAULT_TIMESTAMP: LazyLock> = /// Manages MCP clients and their interactions pub struct Capabilities { - clients: HashMap>>, + clients: HashMap>>>, instructions: HashMap, provider: Box, provider_usage: Mutex>, @@ -87,10 +89,12 @@ impl Capabilities { /// Add a new MCP system based on the provided client type // TODO IMPORTANT need to ensure this times out if the system command is broken! pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> { - let mut client: McpClient = match config { + let mut client: Box = match config { SystemConfig::Sse { ref uri, ref envs } => { let transport = SseTransport::new(uri, envs.get_env()); - McpClient::new(transport.start().await?) + let handle = transport.start().await?; + let service = McpService::with_timeout(handle, Duration::from_secs(10)); + Box::new(McpClient::new(service)) } SystemConfig::Stdio { ref cmd, @@ -98,7 +102,9 @@ impl Capabilities { ref envs, } => { let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env()); - McpClient::new(transport.start().await?) + let handle = transport.start().await?; + let service = McpService::with_timeout(handle, Duration::from_secs(10)); + Box::new(McpClient::new(service)) } }; @@ -271,7 +277,10 @@ impl Capabilities { } /// Find and return a reference to the appropriate client for a tool call - fn get_client_for_tool(&self, prefixed_name: &str) -> Option>> { + fn get_client_for_tool( + &self, + prefixed_name: &str, + ) -> Option>>> { prefixed_name .split_once("__") .and_then(|(client_name, _)| self.clients.get(client_name)) diff --git a/crates/goose/src/agents/system.rs b/crates/goose/src/agents/system.rs index b933e5714..10b19f247 100644 --- a/crates/goose/src/agents/system.rs +++ b/crates/goose/src/agents/system.rs @@ -9,7 +9,7 @@ use thiserror::Error; pub enum SystemError { #[error("Failed to start the MCP server from configuration `{0}` within 60 seconds")] Initialization(SystemConfig), - #[error("Failed a client call to an MCP server")] + #[error("Failed a client call to an MCP server: {0}")] Client(#[from] ClientError), #[error("Transport error: {0}")] Transport(#[from] mcp_client::transport::Error), diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 588599fd8..4913b9521 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -1,6 +1,7 @@ use mcp_client::{ - client::{ClientCapabilities, ClientInfo, McpClient}, + client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}, transport::{SseTransport, StdioTransport, Transport}, + McpService, }; use rand::Rng; use rand::SeedableRng; @@ -17,13 +18,24 @@ async fn main() -> Result<(), Box> { ) .init(); - // Create two separate clients with stdio transport - let client1 = create_stdio_client("client1", "1.0.0").await?; - let client2 = create_stdio_client("client2", "1.0.0").await?; - let client3 = create_sse_client("client3", "1.0.0").await?; + let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); + let handle1 = transport1.start().await?; + let service1 = McpService::with_timeout(handle1, Duration::from_secs(30)); + let client1 = McpClient::new(service1); + + let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); + let handle2 = transport2.start().await?; + let service2 = McpService::with_timeout(handle2, Duration::from_secs(30)); + let client2 = McpClient::new(service2); + + let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new()); + let handle3 = transport3.start().await?; + let service3 = McpService::with_timeout(handle3, Duration::from_secs(10)); + let client3 = McpClient::new(service3); // Initialize both clients - let mut clients = vec![client1, client2, client3]; + let mut clients: Vec> = + vec![Box::new(client1), Box::new(client2), Box::new(client3)]; // Initialize all clients for (i, client) in clients.iter_mut().enumerate() { @@ -117,19 +129,3 @@ async fn main() -> Result<(), Box> { Ok(()) } - -async fn create_stdio_client( - _name: &str, - _version: &str, -) -> Result> { - let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); - Ok(McpClient::new(transport.start().await?)) -} - -async fn create_sse_client( - _name: &str, - _version: &str, -) -> Result> { - let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); - Ok(McpClient::new(transport.start().await?)) -} diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index 673ad921c..8b93b372c 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,6 +1,7 @@ use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; +use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, Transport}; +use mcp_client::McpService; use std::collections::HashMap; use std::time::Duration; use tracing_subscriber::EnvFilter; @@ -22,8 +23,11 @@ async fn main() -> Result<()> { // Start transport let handle = transport.start().await?; + // Create the service with timeout middleware + let service = McpService::with_timeout(handle, Duration::from_secs(3)); + // Create client - let mut client = McpClient::new(handle); + let mut client = McpClient::new(service); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 56781b6e2..e43f036cc 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,8 +1,11 @@ use std::collections::HashMap; use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; -use mcp_client::transport::{StdioTransport, Transport}; +use mcp_client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService, + StdioTransport, Transport, +}; +use std::time::Duration; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -22,8 +25,11 @@ async fn main() -> Result<(), ClientError> { // 2) Start the transport to get a handle let transport_handle = transport.start().await?; - // 3) Create the client - let mut client = McpClient::new(transport_handle); + // 3) Create the service with timeout middleware + let service = McpService::with_timeout(transport_handle, Duration::from_secs(10)); + + // 4) Create the client with the middleware-wrapped service + let mut client = McpClient::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index ef0824932..9acd2086d 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -1,10 +1,13 @@ -use std::collections::HashMap; - // This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool. // The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate. use anyhow::Result; -use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; +use mcp_client::client::{ + ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, +}; use mcp_client::transport::{StdioTransport, Transport}; +use mcp_client::McpService; +use std::collections::HashMap; +use std::time::Duration; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -31,8 +34,11 @@ async fn main() -> Result<(), ClientError> { // Start the transport to get a handle let transport_handle = transport.start().await.unwrap(); + // Create the service with timeout middleware + let service = McpService::with_timeout(transport_handle, Duration::from_secs(10)); + // Create client - let mut client = McpClient::new(transport_handle); + let mut client = McpClient::new(service); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 56eed1ad5..1eef8290c 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,6 +1,3 @@ -use std::sync::atomic::{AtomicU64, Ordering}; - -use crate::transport::TransportHandle; use mcp_core::protocol::{ CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, @@ -8,6 +5,7 @@ use mcp_core::protocol::{ }; use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::sync::atomic::{AtomicU64, Ordering}; use thiserror::Error; use tokio::sync::Mutex; use tower::{Service, ServiceExt}; // for Service::ready() @@ -24,14 +22,23 @@ pub enum Error { #[error("Serialization error: {0}")] Serialization(#[from] serde_json::Error), - #[error("Unexpected response from server")] - UnexpectedResponse, + #[error("Unexpected response from server: {0}")] + UnexpectedResponse(String), #[error("Not initialized")] NotInitialized, #[error("Timeout or service not ready")] NotReady, + + #[error("Box error: {0}")] + BoxError(Box), +} + +impl From> for Error { + fn from(err: Box) -> Self { + Error::BoxError(err) + } } #[derive(Serialize, Deserialize)] @@ -54,20 +61,49 @@ pub struct InitializeParams { pub client_info: ClientInfo, } +#[async_trait::async_trait] +pub trait McpClientTrait: Send + Sync { + async fn initialize( + &mut self, + info: ClientInfo, + capabilities: ClientCapabilities, + ) -> Result; + + async fn list_resources( + &self, + next_cursor: Option, + ) -> Result; + + async fn read_resource(&self, uri: &str) -> Result; + + async fn list_tools(&self, next_cursor: Option) -> Result; + + async fn call_tool(&self, name: &str, arguments: Value) -> Result; +} + /// The MCP client is the interface for MCP operations. -pub struct McpClient { - service: Mutex, +pub struct McpClient +where + S: Service + Clone + Send + Sync + 'static, + S::Error: Into, + S::Future: Send, +{ + service: Mutex, next_id: AtomicU64, server_capabilities: Option, } -impl McpClient { - pub fn new(transport_handle: TransportHandle) -> Self { - // Takes TransportHandle directly +impl McpClient +where + S: Service + Clone + Send + Sync + 'static, + S::Error: Into, + S::Future: Send, +{ + pub fn new(service: S) -> Self { Self { - service: Mutex::new(transport_handle), + service: Mutex::new(service), next_id: AtomicU64::new(1), - server_capabilities: None, // set during initialization + server_capabilities: None, } } @@ -87,7 +123,7 @@ impl McpClient { params: Some(params), }); - let response_msg = service.call(request).await?; + let response_msg = service.call(request).await.map_err(Into::into)?; match response_msg { JsonRpcMessage::Response(JsonRpcResponse { @@ -95,7 +131,9 @@ impl McpClient { }) => { // Verify id matches if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { - return Err(Error::UnexpectedResponse); + return Err(Error::UnexpectedResponse( + "id mismatch for JsonRpcResponse".to_string(), + )); } if let Some(err) = error { Err(Error::RpcError { @@ -105,12 +143,14 @@ impl McpClient { } else if let Some(r) = result { Ok(serde_json::from_value(r)?) } else { - Err(Error::UnexpectedResponse) + Err(Error::UnexpectedResponse("missing result".to_string())) } } JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { - return Err(Error::UnexpectedResponse); + return Err(Error::UnexpectedResponse( + "id mismatch for JsonRpcError".to_string(), + )); } Err(Error::RpcError { code: error.code, @@ -119,7 +159,9 @@ impl McpClient { } _ => { // Requests/notifications not expected as a response - Err(Error::UnexpectedResponse) + Err(Error::UnexpectedResponse( + "unexpected message type".to_string(), + )) } } } @@ -135,11 +177,24 @@ impl McpClient { params: Some(params), }); - service.call(notification).await?; + service.call(notification).await.map_err(Into::into)?; Ok(()) } - pub async fn initialize( + // Check if the client has completed initialization + fn completed_initialization(&self) -> bool { + self.server_capabilities.is_some() + } +} + +#[async_trait::async_trait] +impl McpClientTrait for McpClient +where + S: Service + Clone + Send + Sync + 'static, + S::Error: Into, + S::Future: Send, +{ + async fn initialize( &mut self, info: ClientInfo, capabilities: ClientCapabilities, @@ -161,11 +216,7 @@ impl McpClient { Ok(result) } - fn completed_initialization(&self) -> bool { - self.server_capabilities.is_some() - } - - pub async fn list_resources( + async fn list_resources( &self, next_cursor: Option, ) -> Result { @@ -193,7 +244,7 @@ impl McpClient { self.send_request("resources/list", payload).await } - pub async fn read_resource(&self, uri: &str) -> Result { + async fn read_resource(&self, uri: &str) -> Result { if !self.completed_initialization() { return Err(Error::NotInitialized); } @@ -215,7 +266,7 @@ impl McpClient { self.send_request("resources/read", params).await } - pub async fn list_tools(&self, next_cursor: Option) -> Result { + async fn list_tools(&self, next_cursor: Option) -> Result { if !self.completed_initialization() { return Err(Error::NotInitialized); } @@ -234,7 +285,7 @@ impl McpClient { self.send_request("tools/list", payload).await } - pub async fn call_tool(&self, name: &str, arguments: Value) -> Result { + async fn call_tool(&self, name: &str, arguments: Value) -> Result { if !self.completed_initialization() { return Err(Error::NotInitialized); } diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs index ca0e0203e..985d89d16 100644 --- a/crates/mcp-client/src/lib.rs +++ b/crates/mcp-client/src/lib.rs @@ -1,2 +1,7 @@ pub mod client; +pub mod service; pub mod transport; + +pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; +pub use service::McpService; +pub use transport::{SseTransport, StdioTransport, Transport, TransportHandle}; diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs new file mode 100644 index 000000000..3e6b7afc5 --- /dev/null +++ b/crates/mcp-client/src/service.rs @@ -0,0 +1,59 @@ +use futures::future::BoxFuture; +use mcp_core::protocol::JsonRpcMessage; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tower::{timeout::Timeout, Service, ServiceBuilder}; + +use crate::transport::{Error, TransportHandle}; + +/// A wrapper service that implements Tower's Service trait for MCP transport +#[derive(Clone)] +pub struct McpService { + inner: Arc, +} + +impl McpService { + pub fn new(transport: T) -> Self { + Self { + inner: Arc::new(transport), + } + } +} + +impl Service for McpService +where + T: TransportHandle + Send + Sync + 'static, +{ + type Response = JsonRpcMessage; + type Error = Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + // Most transports are always ready, but this could be customized if needed + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: JsonRpcMessage) -> Self::Future { + let transport = self.inner.clone(); + Box::pin(async move { transport.send(request).await }) + } +} + +// Add a convenience constructor for creating a service with timeout +impl McpService +where + T: TransportHandle, +{ + pub fn with_timeout(transport: T, timeout: std::time::Duration) -> Timeout> { + ServiceBuilder::new() + .timeout(timeout) + .service(McpService::new(transport)) + } +} + +// Implement From for our Error type +impl From for Error { + fn from(_: tower::timeout::error::Elapsed) -> Self { + Error::Timeout + } +} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs index a10c1c9f0..aaf8a7f58 100644 --- a/crates/mcp-client/src/transport/mod.rs +++ b/crates/mcp-client/src/transport/mod.rs @@ -1,15 +1,8 @@ -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - use async_trait::async_trait; use mcp_core::protocol::JsonRpcMessage; +use std::collections::HashMap; use thiserror::Error; use tokio::sync::{mpsc, oneshot, RwLock}; -use tower::Service; /// A generic error type for transport operations. #[derive(Debug, Error)] @@ -46,6 +39,15 @@ pub enum Error { #[error("Unexpected transport error: {0}")] Other(String), + + #[error("Box error: {0}")] + BoxError(Box), +} + +impl From> for Error { + fn from(err: Box) -> Self { + Error::BoxError(err) + } } /// A message that can be sent through the transport @@ -59,63 +61,46 @@ pub struct TransportMessage { /// A generic asynchronous transport trait with channel-based communication #[async_trait] -pub trait Transport: Send + Sync + 'static { +pub trait Transport { + type Handle: TransportHandle; + /// Start the transport and establish the underlying connection. /// Returns the transport handle for sending messages. - async fn start(&self) -> Result; + async fn start(&self) -> Result; /// Close the transport and free any resources. async fn close(&self) -> Result<(), Error>; } -#[derive(Clone)] -pub struct TransportHandle { - sender: mpsc::Sender, +#[async_trait] +pub trait TransportHandle: Send + Sync + Clone + 'static { + async fn send(&self, message: JsonRpcMessage) -> Result; } -impl TransportHandle { - pub async fn send(&self, message: JsonRpcMessage) -> Result { - match message { - JsonRpcMessage::Request(request) => { - let (respond_to, response) = oneshot::channel(); - let msg = TransportMessage { - message: JsonRpcMessage::Request(request), - response_tx: Some(respond_to), - }; - self.sender - .send(msg) - .await - .map_err(|_| Error::ChannelClosed)?; - Ok(response.await.map_err(|_| Error::ChannelClosed)??) - } - JsonRpcMessage::Notification(notification) => { - let msg = TransportMessage { - message: JsonRpcMessage::Notification(notification), - response_tx: None, - }; - self.sender - .send(msg) - .await - .map_err(|_| Error::ChannelClosed)?; - Ok(JsonRpcMessage::Nil) // Explicitly return None for notifications - } - _ => Err(Error::Other("Unsupported message type".to_string())), +// Helper function that contains the common send implementation +pub async fn send_message( + sender: &mpsc::Sender, + message: JsonRpcMessage, +) -> Result { + match message { + JsonRpcMessage::Request(request) => { + let (respond_to, response) = oneshot::channel(); + let msg = TransportMessage { + message: JsonRpcMessage::Request(request), + response_tx: Some(respond_to), + }; + sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; + Ok(response.await.map_err(|_| Error::ChannelClosed)??) } - } -} - -impl Service for TransportHandle { - type Response = JsonRpcMessage; - type Error = Error; // Using Transport's Error directly - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, message: JsonRpcMessage) -> Self::Future { - let this = self.clone(); - Box::pin(async move { this.send(message).await }) + JsonRpcMessage::Notification(notification) => { + let msg = TransportMessage { + message: JsonRpcMessage::Notification(notification), + response_tx: None, + }; + sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; + Ok(JsonRpcMessage::Nil) + } + _ => Err(Error::Other("Unsupported message type".to_string())), } } diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 70f40a585..357bc0738 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -10,7 +10,7 @@ use tokio::sync::{mpsc, RwLock}; use tokio::time::{timeout, Duration}; use tracing::warn; -use super::{Transport, TransportHandle}; +use super::{send_message, Transport, TransportHandle}; // Timeout for the endpoint discovery const ENDPOINT_TIMEOUT_SECS: u64 = 5; @@ -203,6 +203,18 @@ impl SseActor { } } +#[derive(Clone)] +pub struct SseTransportHandle { + sender: mpsc::Sender, +} + +#[async_trait::async_trait] +impl TransportHandle for SseTransportHandle { + async fn send(&self, message: JsonRpcMessage) -> Result { + send_message(&self.sender, message).await + } +} + #[derive(Clone)] pub struct SseTransport { sse_url: String, @@ -240,7 +252,9 @@ impl SseTransport { #[async_trait] impl Transport for SseTransport { - async fn start(&self) -> Result { + type Handle = SseTransportHandle; + + async fn start(&self) -> Result { // Set environment variables for (key, value) in &self.env { std::env::set_var(key, value); @@ -270,7 +284,7 @@ impl Transport for SseTransport { ) .await { - Ok(_) => Ok(TransportHandle { sender: tx }), + Ok(_) => Ok(SseTransportHandle { sender: tx }), Err(e) => Err(Error::SseConnection(e.to_string())), } } diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index e32ad38da..032e2d7d5 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -7,7 +7,7 @@ use mcp_core::protocol::JsonRpcMessage; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::sync::mpsc; -use super::{Error, PendingRequests, Transport, TransportHandle, TransportMessage}; +use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage}; /// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. /// @@ -101,6 +101,18 @@ impl StdioActor { } } +#[derive(Clone)] +pub struct StdioTransportHandle { + sender: mpsc::Sender, +} + +#[async_trait::async_trait] +impl TransportHandle for StdioTransportHandle { + async fn send(&self, message: JsonRpcMessage) -> Result { + send_message(&self.sender, message).await + } +} + pub struct StdioTransport { command: String, args: Vec, @@ -149,7 +161,9 @@ impl StdioTransport { #[async_trait] impl Transport for StdioTransport { - async fn start(&self) -> Result { + type Handle = StdioTransportHandle; + + async fn start(&self) -> Result { let (process, stdin, stdout) = self.spawn_process().await?; let (message_tx, message_rx) = mpsc::channel(32); @@ -163,7 +177,7 @@ impl Transport for StdioTransport { tokio::spawn(actor.run()); - let handle = TransportHandle { sender: message_tx }; + let handle = StdioTransportHandle { sender: message_tx }; Ok(handle) }