Skip to content

Commit

Permalink
Add timeout middleware for clients (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 authored Jan 13, 2025
1 parent ea6ec0d commit 669df02
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 127 deletions.
21 changes: 15 additions & 6 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use chrono::{DateTime, TimeZone, Utc};
use mcp_client::McpService;
use rust_decimal_macros::dec;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::system::{SystemConfig, SystemError, SystemInfo, SystemResult};
use crate::prompt_template::load_prompt_file;
use crate::providers::base::{Provider, ProviderUsage};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};

Expand All @@ -20,7 +22,7 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =

/// Manages MCP clients and their interactions
pub struct Capabilities {
clients: HashMap<String, Arc<Mutex<McpClient>>>,
clients: HashMap<String, Arc<Mutex<Box<dyn McpClientTrait>>>>,
instructions: HashMap<String, String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
Expand Down Expand Up @@ -87,18 +89,22 @@ impl Capabilities {
/// Add a new MCP system based on the provided client type
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let mut client: McpClient = match config {
let mut client: Box<dyn McpClientTrait> = match config {
SystemConfig::Sse { ref uri, ref envs } => {
let transport = SseTransport::new(uri, envs.get_env());
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
SystemConfig::Stdio {
ref cmd,
ref args,
ref envs,
} => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
};

Expand Down Expand Up @@ -271,7 +277,10 @@ impl Capabilities {
}

/// Find and return a reference to the appropriate client for a tool call
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<Arc<Mutex<McpClient>>> {
fn get_client_for_tool(
&self,
prefixed_name: &str,
) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
prefixed_name
.split_once("__")
.and_then(|(client_name, _)| self.clients.get(client_name))
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use thiserror::Error;
pub enum SystemError {
#[error("Failed to start the MCP server from configuration `{0}` within 60 seconds")]
Initialization(SystemConfig),
#[error("Failed a client call to an MCP server")]
#[error("Failed a client call to an MCP server: {0}")]
Client(#[from] ClientError),
#[error("Transport error: {0}")]
Transport(#[from] mcp_client::transport::Error),
Expand Down
40 changes: 18 additions & 22 deletions crates/mcp-client/examples/clients.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use mcp_client::{
client::{ClientCapabilities, ClientInfo, McpClient},
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
transport::{SseTransport, StdioTransport, Transport},
McpService,
};
use rand::Rng;
use rand::SeedableRng;
Expand All @@ -17,13 +18,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.init();

// Create two separate clients with stdio transport
let client1 = create_stdio_client("client1", "1.0.0").await?;
let client2 = create_stdio_client("client2", "1.0.0").await?;
let client3 = create_sse_client("client3", "1.0.0").await?;
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
let client1 = McpClient::new(service1);

let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
let client2 = McpClient::new(service2);

let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3);

// Initialize both clients
let mut clients = vec![client1, client2, client3];
let mut clients: Vec<Box<dyn McpClientTrait>> =
vec![Box::new(client1), Box::new(client2), Box::new(client3)];

// Initialize all clients
for (i, client) in clients.iter_mut().enumerate() {
Expand Down Expand Up @@ -117,19 +129,3 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

Ok(())
}

async fn create_stdio_client(
_name: &str,
_version: &str,
) -> Result<McpClient, Box<dyn std::error::Error>> {
let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
Ok(McpClient::new(transport.start().await?))
}

async fn create_sse_client(
_name: &str,
_version: &str,
) -> Result<McpClient, Box<dyn std::error::Error>> {
let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new());
Ok(McpClient::new(transport.start().await?))
}
8 changes: 6 additions & 2 deletions crates/mcp-client/examples/sse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
Expand All @@ -22,8 +23,11 @@ async fn main() -> Result<()> {
// Start transport
let handle = transport.start().await?;

// Create the service with timeout middleware
let service = McpService::with_timeout(handle, Duration::from_secs(3));

// Create client
let mut client = McpClient::new(handle);
let mut client = McpClient::new(service);
println!("Client created\n");

// Initialize
Expand Down
14 changes: 10 additions & 4 deletions crates/mcp-client/examples/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::HashMap;

use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
StdioTransport, Transport,
};
use std::time::Duration;
use tracing_subscriber::EnvFilter;

#[tokio::main]
Expand All @@ -22,8 +25,11 @@ async fn main() -> Result<(), ClientError> {
// 2) Start the transport to get a handle
let transport_handle = transport.start().await?;

// 3) Create the client
let mut client = McpClient::new(transport_handle);
// 3) Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));

// 4) Create the client with the middleware-wrapped service
let mut client = McpClient::new(service);

// Initialize
let server_info = client
Expand Down
14 changes: 10 additions & 4 deletions crates/mcp-client/examples/stdio_integration.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::collections::HashMap;

// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool.
// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate.
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient};
use mcp_client::client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;

#[tokio::main]
Expand All @@ -31,8 +34,11 @@ async fn main() -> Result<(), ClientError> {
// Start the transport to get a handle
let transport_handle = transport.start().await.unwrap();

// Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));

// Create client
let mut client = McpClient::new(transport_handle);
let mut client = McpClient::new(service);

// Initialize
let server_info = client
Expand Down
Loading

0 comments on commit 669df02

Please sign in to comment.