Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout middleware for clients #572

Merged
merged 13 commits into from
Jan 13, 2025
21 changes: 15 additions & 6 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use chrono::{DateTime, TimeZone, Utc};
use mcp_client::McpService;
use rust_decimal_macros::dec;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::system::{SystemConfig, SystemError, SystemInfo, SystemResult};
use crate::prompt_template::load_prompt_file;
use crate::providers::base::{Provider, ProviderUsage};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};

Expand All @@ -20,7 +22,7 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =

/// Manages MCP clients and their interactions
pub struct Capabilities {
clients: HashMap<String, Arc<Mutex<McpClient>>>,
clients: HashMap<String, Arc<Mutex<Box<dyn McpClientTrait>>>>,
instructions: HashMap<String, String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
Expand Down Expand Up @@ -87,18 +89,22 @@ impl Capabilities {
/// Add a new MCP system based on the provided client type
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let mut client: McpClient = match config {
let mut client: Box<dyn McpClientTrait> = match config {
SystemConfig::Sse { ref uri, ref envs } => {
let transport = SseTransport::new(uri, envs.get_env());
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
SystemConfig::Stdio {
ref cmd,
ref args,
ref envs,
} => {
let transport = StdioTransport::new(cmd, args.to_vec(), envs.get_env());
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
};

Expand Down Expand Up @@ -271,7 +277,10 @@ impl Capabilities {
}

/// Find and return a reference to the appropriate client for a tool call
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<Arc<Mutex<McpClient>>> {
fn get_client_for_tool(
&self,
prefixed_name: &str,
) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
prefixed_name
.split_once("__")
.and_then(|(client_name, _)| self.clients.get(client_name))
Expand Down
2 changes: 1 addition & 1 deletion crates/goose/src/agents/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use thiserror::Error;
pub enum SystemError {
#[error("Failed to start the MCP server from configuration `{0}` within 60 seconds")]
Initialization(SystemConfig),
#[error("Failed a client call to an MCP server")]
#[error("Failed a client call to an MCP server: {0}")]
Client(#[from] ClientError),
#[error("Transport error: {0}")]
Transport(#[from] mcp_client::transport::Error),
Expand Down
40 changes: 18 additions & 22 deletions crates/mcp-client/examples/clients.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use mcp_client::{
client::{ClientCapabilities, ClientInfo, McpClient},
client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait},
transport::{SseTransport, StdioTransport, Transport},
McpService,
};
use rand::Rng;
use rand::SeedableRng;
Expand All @@ -17,13 +18,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.init();

// Create two separate clients with stdio transport
let client1 = create_stdio_client("client1", "1.0.0").await?;
let client2 = create_stdio_client("client2", "1.0.0").await?;
let client3 = create_sse_client("client3", "1.0.0").await?;
let transport1 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle1 = transport1.start().await?;
let service1 = McpService::with_timeout(handle1, Duration::from_secs(30));
let client1 = McpClient::new(service1);

let transport2 = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
let handle2 = transport2.start().await?;
let service2 = McpService::with_timeout(handle2, Duration::from_secs(30));
let client2 = McpClient::new(service2);

let transport3 = SseTransport::new("http://localhost:8000/sse", HashMap::new());
let handle3 = transport3.start().await?;
let service3 = McpService::with_timeout(handle3, Duration::from_secs(10));
let client3 = McpClient::new(service3);

// Initialize both clients
let mut clients = vec![client1, client2, client3];
let mut clients: Vec<Box<dyn McpClientTrait>> =
vec![Box::new(client1), Box::new(client2), Box::new(client3)];

// Initialize all clients
for (i, client) in clients.iter_mut().enumerate() {
Expand Down Expand Up @@ -117,19 +129,3 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

Ok(())
}

async fn create_stdio_client(
_name: &str,
_version: &str,
) -> Result<McpClient, Box<dyn std::error::Error>> {
let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new());
Ok(McpClient::new(transport.start().await?))
}

async fn create_sse_client(
_name: &str,
_version: &str,
) -> Result<McpClient, Box<dyn std::error::Error>> {
let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new());
Ok(McpClient::new(transport.start().await?))
}
8 changes: 6 additions & 2 deletions crates/mcp-client/examples/sse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;
Expand All @@ -22,8 +23,11 @@ async fn main() -> Result<()> {
// Start transport
let handle = transport.start().await?;

// Create the service with timeout middleware
let service = McpService::with_timeout(handle, Duration::from_secs(3));

// Create client
let mut client = McpClient::new(handle);
let mut client = McpClient::new(service);
println!("Client created\n");

// Initialize
Expand Down
14 changes: 10 additions & 4 deletions crates/mcp-client/examples/stdio.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::HashMap;

use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait, McpService,
StdioTransport, Transport,
};
use std::time::Duration;
use tracing_subscriber::EnvFilter;

#[tokio::main]
Expand All @@ -22,8 +25,11 @@ async fn main() -> Result<(), ClientError> {
// 2) Start the transport to get a handle
let transport_handle = transport.start().await?;

// 3) Create the client
let mut client = McpClient::new(transport_handle);
// 3) Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));

// 4) Create the client with the middleware-wrapped service
let mut client = McpClient::new(service);

// Initialize
let server_info = client
Expand Down
14 changes: 10 additions & 4 deletions crates/mcp-client/examples/stdio_integration.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::collections::HashMap;

// This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool.
// The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate.
use anyhow::Result;
use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient};
use mcp_client::client::{
ClientCapabilities, ClientInfo, Error as ClientError, McpClient, McpClientTrait,
};
use mcp_client::transport::{StdioTransport, Transport};
use mcp_client::McpService;
use std::collections::HashMap;
use std::time::Duration;
use tracing_subscriber::EnvFilter;

#[tokio::main]
Expand All @@ -31,8 +34,11 @@ async fn main() -> Result<(), ClientError> {
// Start the transport to get a handle
let transport_handle = transport.start().await.unwrap();

// Create the service with timeout middleware
let service = McpService::with_timeout(transport_handle, Duration::from_secs(10));

// Create client
let mut client = McpClient::new(transport_handle);
let mut client = McpClient::new(service);

// Initialize
let server_info = client
Expand Down
Loading
Loading