diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs deleted file mode 100644 index eb0833489..000000000 --- a/crates/goose-cli/src/agents/agent.rs +++ /dev/null @@ -1,28 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use futures::stream::BoxStream; -use goose::{ - agent::Agent as GooseAgent, message::Message, providers::base::ProviderUsage, systems::System, -}; - -#[async_trait] -pub trait Agent { - fn add_system(&mut self, system: Box); - async fn reply(&self, messages: &[Message]) -> Result>>; - async fn usage(&self) -> Result>; -} - -#[async_trait] -impl Agent for GooseAgent { - fn add_system(&mut self, system: Box) { - self.add_system(system); - } - - async fn reply(&self, messages: &[Message]) -> Result>> { - self.reply(messages).await - } - - async fn usage(&self) -> Result> { - self.usage().await - } -} diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index 542431ac3..090feabd8 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -1,23 +1,59 @@ -use std::vec; - -use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{message::Message, providers::base::ProviderUsage, systems::System}; +use goose::providers::mock::MockProvider; +use goose::{ + agents::Agent, + errors::AgentResult, + message::Message, + providers::base::{Provider, ProviderUsage}, + systems::System, +}; +use serde_json::Value; +use tokio::sync::Mutex; -use crate::agents::agent::Agent; +pub struct MockAgent { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} -pub struct MockAgent; +impl MockAgent { + pub fn new() -> Self { + Self { + systems: Vec::new(), + provider: Box::new(MockProvider::new(Vec::new())), + provider_usage: Mutex::new(Vec::new()), + } + } +} #[async_trait] impl Agent for MockAgent { - fn add_system(&mut self, _system: Box) {} + async fn add_system(&mut self, system: Box) -> 
AgentResult<()> { + self.systems.push(system); + Ok(()) + } + + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + self.systems.retain(|s| s.name() != name); + Ok(()) + } + + async fn list_systems(&self) -> AgentResult> { + Ok(self.systems.iter() + .map(|s| (s.name().to_string(), s.description().to_string())) + .collect()) + } + + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) + } - async fn reply(&self, _messages: &[Message]) -> Result>> { + async fn reply(&self, _messages: &[Message]) -> anyhow::Result>> { Ok(Box::pin(futures::stream::empty())) } - async fn usage(&self) -> Result> { + async fn usage(&self) -> AgentResult> { Ok(vec![ProviderUsage::new( "mock".to_string(), Default::default(), diff --git a/crates/goose-cli/src/agents/mod.rs b/crates/goose-cli/src/agents/mod.rs index 14b5f00cd..a1a102c66 100644 --- a/crates/goose-cli/src/agents/mod.rs +++ b/crates/goose-cli/src/agents/mod.rs @@ -1,4 +1,2 @@ -pub mod agent; - #[cfg(test)] pub mod mock_agent; diff --git a/crates/goose-cli/src/commands/agent_version.rs b/crates/goose-cli/src/commands/agent_version.rs new file mode 100644 index 000000000..2f1628367 --- /dev/null +++ b/crates/goose-cli/src/commands/agent_version.rs @@ -0,0 +1,28 @@ +use anyhow::Result; +use clap::Args; +use goose::agents::AgentFactory; +use std::fmt::Write; + +#[derive(Args)] +pub struct AgentCommand {} + +impl AgentCommand { + pub fn run(&self) -> Result<()> { + let mut output = String::new(); + writeln!(output, "Available agent versions:")?; + + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version(); + + for version in versions { + if version == default_version { + writeln!(output, "* {} (default)", version)?; + } else { + writeln!(output, " {}", version)?; + } + } + + print!("{}", output); + Ok(()) + } +} diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index 9420b16f7..b84916a20 
100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,4 @@ +pub mod agent_version; pub mod configure; pub mod session; pub mod version; -pub mod expected_config; \ No newline at end of file diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 0e23dac49..88ed4abee 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,5 +1,5 @@ use console::style; -use goose::agent::Agent; +use goose::agents::AgentFactory; use goose::providers::factory; use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session}; pub fn build_session<'a>( session: Option, profile: Option, + agent_version: Option, resume: bool, ) -> Box> { let session_dir = ensure_session_dir().expect("Failed to create session directory"); @@ -45,7 +46,7 @@ pub fn build_session<'a>( // TODO: Odd to be prepping the provider rather than having that done in the agent? 
let provider = factory::get_provider(provider_config).unwrap(); - let agent = Box::new(Agent::new(provider)); + let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("default"), provider).unwrap(); let prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "rustyline" => Box::new(RustylinePrompt::new()) as Box, @@ -173,7 +174,7 @@ mod tests { #[should_panic(expected = "Cannot resume session: file")] fn test_resume_nonexistent_session_panics() { run_with_tmp_dir(|| { - build_session(Some("nonexistent-session".to_string()), None, true); + build_session(Some("nonexistent-session".to_string()), None, None, true); }) } @@ -190,7 +191,7 @@ mod tests { fs::write(&file2_path, "{}")?; // Test resuming without a session name - let session = build_session(None, None, true); + let session = build_session(None, None, None, true); assert_eq!(session.session_file().as_path(), file2_path.as_path()); Ok(()) @@ -201,7 +202,7 @@ mod tests { #[should_panic(expected = "No session files found")] fn test_resume_most_recent_session_no_files() { run_with_tmp_dir(|| { - build_session(None, None, true); + build_session(None, None, None, true); }); } } diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index e9b1fa6b2..3b3cd852c 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -1,25 +1,22 @@ -mod commands { - pub mod configure; - pub mod session; - pub mod version; -} -pub mod agents; +use anyhow::Result; +use clap::{Parser, Subcommand}; +use goose::agents::AgentFactory; + +mod agents; +mod commands; +mod log_usage; mod profile; mod prompt; -pub mod session; - +mod session; mod systems; -use anyhow::Result; -use clap::{Parser, Subcommand}; +use commands::agent_version::AgentCommand; use commands::configure::handle_configure; use commands::session::build_session; use commands::version::print_version; use profile::has_no_profiles; use std::io::{self, Read}; -mod log_usage; - #[cfg(test)] mod test_helpers; 
@@ -98,6 +95,15 @@ enum Command { )] profile: Option, + /// Agent version to use (e.g., 'default', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'default', 'v1'), defaults to 'default'", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous session #[arg( short, @@ -151,6 +157,15 @@ enum Command { )] name: Option, + /// Agent version to use (e.g., 'default', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'default', 'v1')", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous run #[arg( short, @@ -161,6 +176,9 @@ enum Command { )] resume: bool, }, + + /// List available agent versions + Agents(AgentCommand), } #[derive(Subcommand)] @@ -224,9 +242,25 @@ async fn main() -> Result<()> { Some(Command::Session { name, profile, + agent, resume, }) => { - let mut session = build_session(name, profile, resume); + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + + let mut session = build_session(name, profile, agent, resume); let _ = session.start().await; return Ok(()); } @@ -235,8 +269,24 @@ async fn main() -> Result<()> { input_text, profile, name, + agent, resume, }) => { + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* 
{} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + let contents = if let Some(file_name) = instructions { let file_path = std::path::Path::new(&file_name); std::fs::read_to_string(file_path).expect("Failed to read the instruction file") @@ -249,10 +299,14 @@ async fn main() -> Result<()> { .expect("Failed to read from stdin"); stdin }; - let mut session = build_session(name, profile, resume); + let mut session = build_session(name, profile, agent, resume); let _ = session.headless_start(contents.clone()).await; return Ok(()); } + Some(Command::Agents(cmd)) => { + cmd.run()?; + return Ok(()); + } None => { println!("No command provided - Run 'goose help' to see available commands."); if has_no_profiles().unwrap_or(false) { diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 0d0270d3e..96d652390 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -6,9 +6,11 @@ use std::fs::{self, File}; use std::io::{self, BufRead, Write}; use std::path::PathBuf; -use crate::agents::agent::Agent; +// use crate::agents::agent::Agent; use crate::log_usage::log_usage; use crate::prompt::{InputType, Prompt}; +use goose::agents::Agent; +use goose::errors::AgentResult; use goose::developer::DeveloperSystem; use goose::message::{Message, MessageContent}; use goose::systems::goose_hints::GooseHintsSystem; @@ -101,6 +103,7 @@ pub struct Session<'a> { messages: Vec, } +#[allow(dead_code)] impl<'a> Session<'a> { pub fn new( agent: Box, @@ -132,7 +135,7 @@ impl<'a> Session<'a> { } pub async fn start(&mut self) -> Result<(), Box> { - self.setup_session(); + self.setup_session().await?; self.prompt.goose_ready(); loop { @@ -160,7 +163,7 @@ impl<'a> Session<'a> { &mut self, initial_message: String, ) -> Result<(), Box> { - self.setup_session(); + self.setup_session().await?; self.messages .push(Message::user().with_text(initial_message.as_str())); @@ -311,11 +314,12 @@ We've 
removed the conversation up to the most recent user message } } - fn setup_session(&mut self) { + async fn setup_session(&mut self) -> AgentResult<()> { let system = Box::new(DeveloperSystem::new()); - self.agent.add_system(system); + self.agent.add_system(system).await?; let goosehints_system = Box::new(GooseHintsSystem::new()); - self.agent.add_system(goosehints_system); + self.agent.add_system(goosehints_system).await?; + Ok(()) } async fn close_session(&mut self) { @@ -327,6 +331,7 @@ We've removed the conversation up to the most recent user message .as_str(), )); self.prompt.close(); + match self.agent.usage().await { Ok(usage) => log_usage(self.session_file.to_string_lossy().to_string(), usage), Err(e) => eprintln!("Failed to collect total provider usage: {}", e), @@ -361,14 +366,14 @@ mod tests { // Helper function to create a test session fn create_test_session() -> Session<'static> { let temp_file = NamedTempFile::new().unwrap(); - let agent = Box::new(MockAgent {}); + let agent = Box::new(MockAgent::new()); let prompt = Box::new(MockPrompt::new()); Session::new(agent, prompt, temp_file.path().to_path_buf()) } fn create_test_session_with_prompt<'a>(prompt: Box) -> Session<'a> { let temp_file = NamedTempFile::new().unwrap(); - let agent = Box::new(MockAgent {}); + let agent = Box::new(MockAgent::new()); Session::new(agent, prompt, temp_file.path().to_path_buf()) } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index bbece1276..7d402cec5 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -219,6 +219,12 @@ pub struct Settings { #[serde(default)] pub server: ServerSettings, pub provider: ProviderSettings, + #[serde(default = "default_agent_version")] + pub agent_version: Option, +} + +fn default_agent_version() -> Option { + None // Will use AgentFactory::default_version() when None } impl Settings { diff --git a/crates/goose-server/src/main.rs 
b/crates/goose-server/src/main.rs index d575596a7..4c0eb2b82 100644 --- a/crates/goose-server/src/main.rs +++ b/crates/goose-server/src/main.rs @@ -19,7 +19,11 @@ async fn main() -> anyhow::Result<()> { std::env::var("GOOSE_SERVER__SECRET_KEY").unwrap_or_else(|_| "test".to_string()); // Create app state - let state = state::AppState::new(settings.provider.into_config(), secret_key.clone())?; + let state = state::AppState::new( + settings.provider.into_config(), + secret_key.clone(), + settings.agent_version, + ).await?; // Create router with CORS support let cors = CorsLayer::new() diff --git a/crates/goose-server/src/routes/agent.rs b/crates/goose-server/src/routes/agent.rs new file mode 100644 index 000000000..4b78a7319 --- /dev/null +++ b/crates/goose-server/src/routes/agent.rs @@ -0,0 +1,28 @@ +use crate::state::AppState; +use axum::{extract::State, routing::get, Json, Router}; +use goose::agents::AgentFactory; +use serde::Serialize; + +#[derive(Serialize)] +struct VersionsResponse { + current_version: String, + available_versions: Vec, + default_version: String, +} + +async fn get_versions(State(state): State) -> Json { + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version().to_string(); + + Json(VersionsResponse { + current_version: state.agent_version.clone(), + available_versions: versions.iter().map(|v| v.to_string()).collect(), + default_version, + }) +} + +pub fn routes(state: AppState) -> Router { + Router::new() + .route("/api/agent/versions", get(get_versions)) + .with_state(state) +} diff --git a/crates/goose-server/src/routes/mod.rs b/crates/goose-server/src/routes/mod.rs index 2d798a0da..b430136d0 100644 --- a/crates/goose-server/src/routes/mod.rs +++ b/crates/goose-server/src/routes/mod.rs @@ -1,9 +1,12 @@ // Export route modules +pub mod agent; pub mod reply; use axum::Router; // Function to configure all routes pub fn configure(state: crate::state::AppState) -> Router { - 
Router::new().merge(reply::routes(state)) + Router::new() + .merge(reply::routes(state.clone())) + .merge(agent::routes(state)) } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index c6bd48964..b5269aa8b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -390,7 +390,7 @@ pub fn routes(state: AppState) -> Router { mod tests { use super::*; use goose::{ - agent::Agent, + agents::DefaultAgent as Agent, providers::{ base::{Provider, ProviderUsage, Usage}, configs::{ModelConfig, OpenAiProviderConfig}, @@ -422,7 +422,7 @@ mod tests { &self.model_config } - fn get_usage(&self, data: &Value) -> anyhow::Result { + fn get_usage(&self, _data: &Value) -> anyhow::Result { Ok(Usage::new(None, None, None)) } } @@ -518,13 +518,14 @@ mod tests { }); let agent = Agent::new(mock_provider); let state = AppState { - agent: Arc::new(Mutex::new(agent)), + agent: Arc::new(Mutex::new(Box::new(agent))), provider_config: ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".to_string(), api_key: "test-key".to_string(), model: ModelConfig::new("test-model".to_string()), }), secret_key: "test-secret".to_string(), + agent_version: "test-version".to_string(), }; // Build router diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index fa430c316..f62f96d07 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -1,7 +1,8 @@ use anyhow::Result; use goose::providers::configs::GroqProviderConfig; use goose::{ - agent::Agent, + agents::Agent, + agents::AgentFactory, developer::DeveloperSystem, memory::MemorySystem, providers::{configs::ProviderConfig, factory}, @@ -13,30 +14,45 @@ use tokio::sync::Mutex; /// Shared application state pub struct AppState { pub provider_config: ProviderConfig, - pub agent: Arc>, + pub agent: Arc>>, pub secret_key: String, + pub agent_version: String, } impl AppState { - pub fn 
new(provider_config: ProviderConfig, secret_key: String) -> Result { + pub async fn new( + provider_config: ProviderConfig, + secret_key: String, + agent_version: Option, + ) -> Result { let provider = factory::get_provider(provider_config.clone())?; - let mut agent = Agent::new(provider); - agent.add_system(Box::new(DeveloperSystem::new())); + let mut agent = AgentFactory::create( + agent_version + .clone() + .unwrap_or(AgentFactory::default_version().to_string()) + .as_str(), + provider, + )?; + + agent.add_system(Box::new(DeveloperSystem::new())).await?; // Add memory system only if GOOSE_SERVER__MEMORY is set to "true" if let Ok(memory_enabled) = env::var("GOOSE_SERVER__MEMORY") { if memory_enabled.to_lowercase() == "true" { - agent.add_system(Box::new(MemorySystem::new())); + agent.add_system(Box::new(MemorySystem::new())).await?; } } let goosehints_system = Box::new(GooseHintsSystem::new()); - agent.add_system(goosehints_system); + agent.add_system(goosehints_system).await?; Ok(Self { provider_config, agent: Arc::new(Mutex::new(agent)), secret_key, + agent_version: agent_version + .clone() + .unwrap_or(AgentFactory::default_version().to_string()), }) } } @@ -89,6 +105,7 @@ impl Clone for AppState { }, agent: self.agent.clone(), secret_key: self.secret_key.clone(), + agent_version: self.agent_version.clone(), } } } diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 47131df0d..b3b6e7fd3 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -67,6 +67,8 @@ keyring = { version = "3.6.1", features = [ shellexpand = "3.1.0" rust_decimal = "1.36.0" rust_decimal_macros = "1.36.0" +ctor = "0.2.7" +paste = "1.0" [dev-dependencies] sysinfo = "0.32.1" diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs deleted file mode 100644 index 5c02bf3be..000000000 --- a/crates/goose/src/agent.rs +++ /dev/null @@ -1,781 +0,0 @@ -use anyhow::Result; -use async_stream; -use futures::stream::BoxStream; -use rust_decimal_macros::dec; -use 
serde_json::json; -use std::collections::HashMap; -use tokio::sync::Mutex; - -use crate::errors::{AgentError, AgentResult}; -use crate::message::{Message, ToolRequest}; -use crate::prompt_template::load_prompt_file; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::systems::System; -use crate::token_counter::TokenCounter; -use mcp_core::{Content, Resource, Tool, ToolCall}; -use serde::Serialize; - -// used to sort resources by priority within error margin -const PRIORITY_EPSILON: f32 = 0.001; - -#[derive(Clone, Debug, Serialize)] -struct SystemInfo { - name: String, - description: String, - instructions: String, -} - -impl SystemInfo { - fn new(name: &str, description: &str, instructions: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - instructions: instructions.to_string(), - } - } -} - -#[derive(Clone, Debug, Serialize)] -struct SystemStatus { - name: String, - status: String, -} - -impl SystemStatus { - fn new(name: &str, status: String) -> Self { - Self { - name: name.to_string(), - status, - } - } -} - -/// Agent integrates a foundational LLM with the systems it needs to pilot -pub struct Agent { - systems: Vec>, - provider: Box, - provider_usage: Mutex>, -} - -#[allow(dead_code)] -impl Agent { - /// Create a new Agent with the specified provider - pub fn new(provider: Box) -> Self { - Self { - systems: Vec::new(), - provider, - provider_usage: Mutex::new(Vec::new()), - } - } - - /// Add a system to the agent - pub fn add_system(&mut self, system: Box) { - self.systems.push(system); - } - - /// Get the context limit from the provider's configuration - fn get_context_limit(&self) -> usize { - self.provider.get_model_config().context_limit() - } - - /// Get all tools from all systems with proper system prefixing - fn get_prefixed_tools(&self) -> Vec { - let mut tools = Vec::new(); - for system in &self.systems { - for tool in system.tools() { - tools.push(Tool::new( - format!("{}__{}", system.name(), 
tool.name), - &tool.description, - tool.input_schema.clone(), - )); - } - } - tools - } - - /// Find the appropriate system for a tool call based on the prefixed name - fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { - let parts: Vec<&str> = prefixed_name.split("__").collect(); - if parts.len() != 2 { - return None; - } - let system_name = parts[0]; - self.systems - .iter() - .find(|sys| sys.name() == system_name) - .map(|v| &**v) - } - - /// Dispatch a single tool call to the appropriate system - async fn dispatch_tool_call( - &self, - tool_call: AgentResult, - ) -> AgentResult> { - let call = tool_call?; - let system = self - .get_system_for_tool(&call.name) - .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; - - let tool_name = call - .name - .split("__") - .nth(1) - .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; - let system_tool_call = ToolCall::new(tool_name, call.arguments); - - system.call(system_tool_call).await - } - - fn get_system_prompt(&self) -> AgentResult { - let mut context = HashMap::new(); - let systems_info: Vec = self - .systems - .iter() - .map(|system| { - SystemInfo::new(system.name(), system.description(), system.instructions()) - }) - .collect(); - - context.insert("systems", systems_info); - load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) - } - - async fn get_systems_resources( - &self, - ) -> AgentResult>> { - let mut system_resource_content: HashMap> = - HashMap::new(); - for system in &self.systems { - let system_status = system - .status() - .await - .map_err(|e| AgentError::Internal(e.to_string()))?; - - let mut resource_content: HashMap = HashMap::new(); - for resource in system_status { - if let Ok(content) = system.read_resource(&resource.uri).await { - resource_content.insert(resource.uri.to_string(), (resource, content)); - } - } - system_resource_content.insert(system.name().to_string(), resource_content); - } - 
Ok(system_resource_content) - } - - /// Setup the next inference by budgeting the context window as well as we can - async fn prepare_inference( - &self, - system_prompt: &str, - tools: &[Tool], - messages: &[Message], - pending: &Vec, - target_limit: usize, - ) -> AgentResult> { - // Prepares the inference by managing context window and token budget. - // This function: - // 1. Retrieves and formats system resources and status - // 2. Trims content if total tokens exceed the model's context limit - // 3. Adds pending messages if any. Pending messages are messages that have been added - // to the conversation but not yet responded to. - // 4. Adds two messages to the conversation: - // - A tool request message for status - // - A tool response message containing the (potentially trimmed) status - // - // Returns the updated message history with status information appended. - // - // Arguments: - // * `system_prompt` - The system prompt to include - // * `tools` - Available tools for the agent - // * `messages` - Current conversation history - // - // Returns: - // * `AgentResult>` - Updated message history with status appended - - let token_counter = TokenCounter::new(); - let resource_content = self.get_systems_resources().await?; - - // Flatten all resource content into a vector of strings - let mut resources = Vec::new(); - for system_resources in resource_content.values() { - for (_, content) in system_resources.values() { - resources.push(content.clone()); - } - } - - let approx_count = token_counter.count_everything( - system_prompt, - messages, - tools, - &resources, - Some(&self.provider.get_model_config().model_name), - ); - - let mut status_content: Vec = Vec::new(); - - if approx_count > target_limit { - println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over buget. 
Removing context", approx_count, approx_count - target_limit); - - // Get token counts for each resourcee - let mut system_token_counts = HashMap::new(); - - // Iterate through each system and its resources - for (system_name, resources) in &resource_content { - let mut resource_counts = HashMap::new(); - for (uri, (_resource, content)) in resources { - let token_count = token_counter - .count_tokens(&content, Some(&self.provider.get_model_config().model_name)) - as u32; - resource_counts.insert(uri.clone(), token_count); - } - system_token_counts.insert(system_name.clone(), resource_counts); - } - // Sort resources by priority and timestamp and trim to fit context limit - let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); - for (system_name, resources) in &resource_content { - for (uri, (resource, _)) in resources { - if let Some(token_count) = system_token_counts - .get(system_name) - .and_then(|counts| counts.get(uri)) - { - all_resources.push(( - system_name.clone(), - uri.clone(), - resource.clone(), - *token_count, - )); - } - } - } - - // Sort by priority (high to low) and timestamp (newest to oldest) - // since priority is float, we need to sort by priority within error margin - PRIORITY_EPSILON - all_resources.sort_by(|a, b| { - // Compare priorities with epsilon - // Compare priorities with Option handling - default to 0.0 if None - let a_priority = a.2.priority().unwrap_or(0.0); - let b_priority = b.2.priority().unwrap_or(0.0); - if (b_priority - a_priority).abs() < PRIORITY_EPSILON { - // Priorities are "equal" within epsilon, use timestamp as tiebreaker - b.2.timestamp().cmp(&a.2.timestamp()) - } else { - // Priorities are different enough, use priority ordering - b.2.priority() - .partial_cmp(&a.2.priority()) - .unwrap_or(std::cmp::Ordering::Equal) - } - }); - - // Remove resources until we're under target limit - let mut current_tokens = approx_count; - - while current_tokens > target_limit && !all_resources.is_empty() { - if 
let Some((system_name, uri, _, token_count)) = all_resources.pop() { - if let Some(system_counts) = system_token_counts.get_mut(&system_name) { - system_counts.remove(&uri); - current_tokens -= token_count as usize; - } - } - } - // Create status messages only from resources that remain after token trimming - for (system_name, uri, _, _) in &all_resources { - if let Some(system_resources) = resource_content.get(system_name) { - if let Some((resource, content)) = system_resources.get(uri) { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - } else { - // Create status messages from all resources when no trimming needed - for resources in resource_content.values() { - for (resource, content) in resources.values() { - status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); - } - } - } - - // Join remaining status content and create status message - let status_str = status_content.join("\n"); - let mut context = HashMap::new(); - let systems_status = vec![SystemStatus::new("system", status_str)]; - context.insert("systems", &systems_status); - - // Load and format the status template with only remaining resources - let status = load_prompt_file("status.md", &context) - .map_err(|e| AgentError::Internal(e.to_string()))?; - - // Create a new messages vector with our changes - let mut new_messages = messages.to_vec(); - - // Add pending messages - for msg in pending { - new_messages.push(msg.clone()); - } - - // Finally add the status messages - let message_use = - Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); - - let message_result = - Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); - - new_messages.push(message_use); - new_messages.push(message_result); - - Ok(new_messages) - } - - /// Create a stream that yields each message as it's generated by the agent. - /// This includes both the assistant's responses and any tool responses. 
- pub async fn reply(&self, messages: &[Message]) -> Result>> { - let mut messages = messages.to_vec(); - let tools = self.get_prefixed_tools(); - let system_prompt = self.get_system_prompt()?; - let estimated_limit = self.provider.get_model_config().get_estimated_limit(); - - // Update conversation history for the start of the reply - messages = self - .prepare_inference( - &system_prompt, - &tools, - &messages, - &Vec::new(), - estimated_limit, - ) - .await?; - - Ok(Box::pin(async_stream::try_stream! { - loop { - // Get completion from provider - let (response, usage) = self.provider.complete( - &system_prompt, - &messages, - &tools, - ).await?; - self.provider_usage.lock().await.push(usage); - - // The assistant's response is added in rewrite_messages_on_tool_response - // Yield the assistant's response - yield response.clone(); - - // Not sure why this is needed, but this ensures that the above message is yielded - // before the following potentially long-running commands start processing - tokio::task::yield_now().await; - - // First collect any tool requests - let tool_requests: Vec<&ToolRequest> = response.content - .iter() - .filter_map(|content| content.as_tool_request()) - .collect(); - - if tool_requests.is_empty() { - // No more tool calls, end the reply loop - break; - } - - // Then dispatch each in parallel - let futures: Vec<_> = tool_requests - .iter() - .map(|request| self.dispatch_tool_call(request.tool_call.clone())) - .collect(); - - // Process all the futures in parallel but wait until all are finished - let outputs = futures::future::join_all(futures).await; - - // Create a message with the responses - let mut message_tool_response = Message::user(); - // Now combine these into MessageContent::ToolResponse using the original ID - for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); - } - - yield 
message_tool_response.clone(); - - // Now we have to remove the previous status tooluse and toolresponse - // before we add pending messages, then the status msgs back again - messages.pop(); - messages.pop(); - - let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; - } - })) - } - - pub async fn usage(&self) -> Result> { - let provider_usage = self.provider_usage.lock().await.clone(); - - let mut usage_map: HashMap = HashMap::new(); - provider_usage.iter().for_each(|usage| { - usage_map - .entry(usage.model.clone()) - .and_modify(|e| { - e.usage.input_tokens = Some( - e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), - ); - e.usage.output_tokens = Some( - e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), - ); - e.usage.total_tokens = Some( - e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), - ); - if e.cost.is_none() || usage.cost.is_none() { - e.cost = None; // Pricing is not available for all models - } else { - e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); - } - }) - .or_insert_with(|| usage.clone()); - }); - Ok(usage_map.into_values().collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::message::MessageContent; - use crate::providers::configs::ModelConfig; - use crate::providers::mock::MockProvider; - use async_trait::async_trait; - use chrono::Utc; - use futures::TryStreamExt; - use mcp_core::resource::Resource; - use mcp_core::Annotations; - use serde_json::json; - - // Mock system for testing - struct MockSystem { - name: String, - tools: Vec, - resources: Vec, - resource_content: HashMap, - } - - impl MockSystem { - fn new(name: &str) -> Self { - Self { - name: name.to_string(), - tools: vec![Tool::new( - "echo", - "Echoes back the input", - json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": 
["message"]}), - )], - resources: Vec::new(), - resource_content: HashMap::new(), - } - } - - fn add_resource(&mut self, name: &str, content: &str, priority: f32) { - let uri = format!("file://{}", name); - let resource = Resource { - name: name.to_string(), - uri: uri.clone(), - annotations: Some(Annotations::for_resource(priority, Utc::now())), - description: Some("A mock resource".to_string()), - mime_type: "text/plain".to_string(), - }; - self.resources.push(resource); - self.resource_content.insert(uri, content.to_string()); - } - } - - #[async_trait] - impl System for MockSystem { - fn name(&self) -> &str { - &self.name - } - - fn description(&self) -> &str { - "A mock system for testing" - } - - fn instructions(&self) -> &str { - "Mock system instructions" - } - - fn tools(&self) -> &[Tool] { - &self.tools - } - - async fn status(&self) -> anyhow::Result> { - Ok(self.resources.clone()) - } - - async fn call(&self, tool_call: ToolCall) -> AgentResult> { - match tool_call.name.as_str() { - "echo" => Ok(vec![Content::text( - tool_call.arguments["message"].as_str().unwrap_or(""), - )]), - _ => Err(AgentError::ToolNotFound(tool_call.name)), - } - } - - async fn read_resource(&self, uri: &str) -> AgentResult { - self.resource_content.get(uri).cloned().ok_or_else(|| { - AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) - }) - } - } - - #[tokio::test] - async fn test_simple_response() -> Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = Agent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? 
{ - messages.push(msg); - } - - assert_eq!(messages.len(), 1); - assert_eq!(messages[0], response); - Ok(()) - } - - #[tokio::test] - async fn test_usage_rollup() -> Result<()> { - let response = Message::assistant().with_text("Hello!"); - let provider = MockProvider::new(vec![response.clone()]); - let agent = Agent::new(Box::new(provider)); - - let initial_message = Message::user().with_text("Hi"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - // Second message - let mut stream = agent.reply(&initial_messages).await?; - while stream.try_next().await?.is_some() {} - - let usage = agent.usage().await?; - assert_eq!(usage.len(), 1); // 2 messages rolled up to one usage per model - assert_eq!(usage[0].usage.input_tokens, Some(2)); - assert_eq!(usage[0].usage.output_tokens, Some(2)); - assert_eq!(usage[0].usage.total_tokens, Some(4)); - assert_eq!(usage[0].model, "mock"); - assert_eq!(usage[0].cost, Some(dec!(2))); - Ok(()) - } - - #[tokio::test] - async fn test_tool_call() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant().with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "test"}))), - ), - Message::assistant().with_text("Done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Echo test"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? 
{ - messages.push(msg); - } - - // Should have three messages: tool request, response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("Done!")); - Ok(()) - } - - #[tokio::test] - async fn test_invalid_tool() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request("1", Ok(ToolCall::new("invalid_tool", json!({})))), - Message::assistant().with_text("Error occurred"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Invalid tool"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? { - messages.push(msg); - } - - // Should have three messages: failed tool request, fail response, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!( - messages[2].content[0], - MessageContent::text("Error occurred") - ); - Ok(()) - } - - #[tokio::test] - async fn test_multiple_tool_calls() -> Result<()> { - let mut agent = Agent::new(Box::new(MockProvider::new(vec![ - Message::assistant() - .with_tool_request( - "1", - Ok(ToolCall::new("test_echo", json!({"message": "first"}))), - ) - .with_tool_request( - "2", - Ok(ToolCall::new("test_echo", json!({"message": "second"}))), - ), - Message::assistant().with_text("All done!"), - ]))); - - agent.add_system(Box::new(MockSystem::new("test"))); - - let initial_message = Message::user().with_text("Multiple calls"); - let initial_messages = vec![initial_message]; - - let mut stream = agent.reply(&initial_messages).await?; - let mut messages = Vec::new(); - while let Some(msg) = stream.try_next().await? 
{ - messages.push(msg); - } - - // Should have three messages: tool requests, responses, and model text - assert_eq!(messages.len(), 3); - assert!(messages[0] - .content - .iter() - .any(|c| matches!(c, MessageContent::ToolRequest(_)))); - assert_eq!(messages[2].content[0], MessageContent::text("All done!")); - Ok(()) - } - - #[tokio::test] - async fn test_prepare_inference_trims_resources_when_budget_exceeded() -> Result<()> { - // Create a mock provider - let provider = MockProvider::new(vec![]); - let mut agent = Agent::new(Box::new(provider)); - - // Create a mock system with two resources - let mut system = MockSystem::new("test"); - - // Add two resources with different priorities - let string_10toks = "hello ".repeat(10); - system.add_resource("high_priority", &string_10toks, 0.8); - system.add_resource("low_priority", &string_10toks, 0.1); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Approx count is 40, so target limit of 35 will force trimming - let target_limit = 35; - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(!status_content.contains("low_priority")); - - // Now test with a target limit that allows both resources (no trimming) - let target_limit = 100; - - // Call prepare_inference - let result = agent - 
.prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // Verify that only the high priority resource is included in the status - assert!(status_content.contains("high_priority")); - assert!(status_content.contains("low_priority")); - Ok(()) - } - - #[tokio::test] - async fn test_context_trimming_with_custom_model_config() -> Result<()> { - let provider = MockProvider::with_config( - vec![], - ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), - ); - let mut agent = Agent::new(Box::new(provider)); - - // Create a mock system with a resource that will exceed the context limit - let mut system = MockSystem::new("test"); - - // Add a resource that will exceed our tiny context limit - let hello_1_tokens = "hello ".repeat(1); // 1 tokens - let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens - system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); - system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); - - agent.add_system(Box::new(system)); - - // Set up test parameters - // 18 tokens with system + user msg in chat format - let system_prompt = "This is a system prompt"; - let messages = vec![Message::user().with_text("Hi there")]; - let tools = vec![]; - let pending = vec![]; - - // Use the context limit from the model config - let target_limit = agent.get_context_limit(); - assert_eq!(target_limit, 20, "Context limit should be 20"); - - // Call prepare_inference - let result = agent - .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) - .await?; - - // Get the last message which should be the tool response containing status - let status_message = result.last().unwrap(); - let 
status_content = status_message - .content - .first() - .and_then(|content| content.as_tool_response_text()) - .unwrap_or_default(); - - // verify that "hello" is within the response, should be just under 20 tokens with "hello" - assert!(status_content.contains("hello")); - - Ok(()) - } -} diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs new file mode 100644 index 000000000..5da9167d8 --- /dev/null +++ b/crates/goose/src/agents/agent.rs @@ -0,0 +1,31 @@ +use anyhow::Result; +use async_trait::async_trait; +use futures::stream::BoxStream; +use serde_json::Value; + +use crate::errors::AgentResult; +use crate::message::Message; +use crate::providers::base::ProviderUsage; +use crate::systems::System; + +/// Core trait defining the behavior of an Agent +#[async_trait] +pub trait Agent: Send + Sync { + /// Create a stream that yields each message as it's generated by the agent + async fn reply(&self, messages: &[Message]) -> Result>>; + + /// Add a system to the agent + async fn add_system(&mut self, system: Box) -> AgentResult<()>; + + /// Remove a system by name + async fn remove_system(&mut self, name: &str) -> AgentResult<()>; + + /// List all systems and their status + async fn list_systems(&self) -> AgentResult>; + + /// Pass through a JSON-RPC request to a specific system + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult; + + /// Get the total usage of the agent + async fn usage(&self) -> AgentResult>; +} diff --git a/crates/goose/src/agents/default.rs b/crates/goose/src/agents/default.rs new file mode 100644 index 000000000..f5f75d612 --- /dev/null +++ b/crates/goose/src/agents/default.rs @@ -0,0 +1,505 @@ +use async_trait::async_trait; +use futures::stream::BoxStream; +use serde::Serialize; +use serde_json::json; +use tokio::sync::Mutex; +use std::collections::HashMap; + +use super::{Agent, MCPManager}; +use crate::errors::{AgentError, AgentResult}; +use crate::message::{Message, ToolRequest}; +use 
crate::providers::base::Provider; +use crate::register_agent; +use crate::systems::System; +use crate::token_counter::TokenCounter; +use mcp_core::{Content, Resource, Tool, ToolCall}; +use crate::prompt_template::load_prompt_file; +use crate::providers::base::ProviderUsage; +use serde_json::Value; +// used to sort resources by priority within error margin +const PRIORITY_EPSILON: f32 = 0.001; + +#[derive(Clone, Debug, Serialize)] +struct SystemStatus { + name: String, + status: String, +} + +impl SystemStatus { + fn new(name: &str, status: String) -> Self { + Self { + name: name.to_string(), + status, + } + } +} + +/// Default implementation of an Agent +pub struct DefaultAgent { + mcp_manager: Mutex, +} + +impl DefaultAgent { + pub fn new(provider: Box) -> Self { + Self { + mcp_manager: Mutex::new(MCPManager::new(provider)), + } + } + + /// Setup the next inference by budgeting the context window + async fn prepare_inference( + &self, + system_prompt: &str, + tools: &[Tool], + messages: &[Message], + pending: &[Message], + target_limit: usize, + model_name: &str, + resource_content: &HashMap>, + ) -> AgentResult> { + let token_counter = TokenCounter::new(); + + // Flatten all resource content into a vector of strings + let mut resources = Vec::new(); + for system_resources in resource_content.values() { + for (_, content) in system_resources.values() { + resources.push(content.clone()); + } + } + + let approx_count = token_counter.count_everything( + system_prompt, + messages, + tools, + &resources, + Some(model_name), + ); + + let mut status_content: Vec = Vec::new(); + + if approx_count > target_limit { + println!("[WARNING] Token budget exceeded. Current count: {} \n Difference: {} tokens over budget. 
Removing context", approx_count, approx_count - target_limit); + + // Get token counts for each resource + let mut system_token_counts = HashMap::new(); + + // Iterate through each system and its resources + for (system_name, resources) in resource_content { + let mut resource_counts = HashMap::new(); + for (uri, (_resource, content)) in resources { + let token_count = token_counter.count_tokens(&content, Some(model_name)) as u32; + resource_counts.insert(uri.clone(), token_count); + } + system_token_counts.insert(system_name.clone(), resource_counts); + } + + // Sort resources by priority and timestamp and trim to fit context limit + let mut all_resources: Vec<(String, String, Resource, u32)> = Vec::new(); + for (system_name, resources) in resource_content { + for (uri, (resource, _)) in resources { + if let Some(token_count) = system_token_counts + .get(system_name) + .and_then(|counts| counts.get(uri)) + { + all_resources.push(( + system_name.clone(), + uri.clone(), + resource.clone(), + *token_count, + )); + } + } + } + + // Sort by priority (high to low) and timestamp (newest to oldest) + all_resources.sort_by(|a, b| { + let a_priority = a.2.priority().unwrap_or(0.0); + let b_priority = b.2.priority().unwrap_or(0.0); + if (b_priority - a_priority).abs() < PRIORITY_EPSILON { + b.2.timestamp().cmp(&a.2.timestamp()) + } else { + b.2.priority() + .partial_cmp(&a.2.priority()) + .unwrap_or(std::cmp::Ordering::Equal) + } + }); + + // Remove resources until we're under target limit + let mut current_tokens = approx_count; + + while current_tokens > target_limit && !all_resources.is_empty() { + if let Some((system_name, uri, _, token_count)) = all_resources.pop() { + if let Some(system_counts) = system_token_counts.get_mut(&system_name) { + system_counts.remove(&uri); + current_tokens -= token_count as usize; + } + } + } + + // Create status messages only from resources that remain after token trimming + for (system_name, uri, _, _) in &all_resources { + if let 
Some(system_resources) = resource_content.get(system_name) { + if let Some((resource, content)) = system_resources.get(uri) { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + } else { + // Create status messages from all resources when no trimming needed + for resources in resource_content.values() { + for (resource, content) in resources.values() { + status_content.push(format!("{}\n```\n{}\n```\n", resource.name, content)); + } + } + } + + // Join remaining status content and create status message + let status_str = status_content.join("\n"); + let mut context = HashMap::new(); + let systems_status = vec![SystemStatus::new("system", status_str)]; + context.insert("systems", &systems_status); + + // Load and format the status template with only remaining resources + let status = load_prompt_file("status.md", &context) + .map_err(|e| AgentError::Internal(e.to_string()))?; + + // Create a new messages vector with our changes + let mut new_messages = messages.to_vec(); + + // Add pending messages + for msg in pending { + new_messages.push(msg.clone()); + } + + // Finally add the status messages + let message_use = + Message::assistant().with_tool_request("000", Ok(ToolCall::new("status", json!({})))); + + let message_result = + Message::user().with_tool_response("000", Ok(vec![Content::text(status)])); + + new_messages.push(message_use); + new_messages.push(message_result); + + Ok(new_messages) + } +} + +#[async_trait] +impl Agent for DefaultAgent { + async fn add_system(&mut self, system: Box) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.add_system(system); + Ok(()) + } + + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.remove_system(name) + } + + async fn list_systems(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.list_systems().await + } + + async fn passthrough(&self, 
_system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) + } + + async fn reply(&self, messages: &[Message]) -> anyhow::Result>> { + let manager = self.mcp_manager.lock().await; + let tools = manager.get_prefixed_tools(); + let system_prompt = manager.get_system_prompt()?; + let estimated_limit = manager.provider().get_model_config().get_estimated_limit(); + + // Update conversation history for the start of the reply + let mut messages = self.prepare_inference( + &system_prompt, + &tools, + messages, + &Vec::new(), + estimated_limit, + &manager.provider().get_model_config().model_name, + &manager.get_systems_resources().await?, + ).await?; + + Ok(Box::pin(async_stream::try_stream! { + loop { + // Get completion from provider + let (response, usage) = manager.provider().complete( + &system_prompt, + &messages, + &tools, + ).await?; + manager.record_usage(usage).await; + + // Yield the assistant's response + yield response.clone(); + + tokio::task::yield_now().await; + + // First collect any tool requests + let tool_requests: Vec<&ToolRequest> = response.content + .iter() + .filter_map(|content| content.as_tool_request()) + .collect(); + + if tool_requests.is_empty() { + break; + } + + // Then dispatch each in parallel + let futures: Vec<_> = tool_requests + .iter() + .map(|request| manager.dispatch_tool_call(request.tool_call.clone())) + .collect(); + + // Process all the futures in parallel but wait until all are finished + let outputs = futures::future::join_all(futures).await; + + // Create a message with the responses + let mut message_tool_response = Message::user(); + // Now combine these into MessageContent::ToolResponse using the original ID + for (request, output) in tool_requests.iter().zip(outputs.into_iter()) { + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); + } + + yield message_tool_response.clone(); + + // Now we have to remove the previous status tooluse and toolresponse + // before 
we add pending messages, then the status msgs back again + messages.pop(); + messages.pop(); + + let pending = vec![response, message_tool_response]; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit, &manager.provider().get_model_config().model_name, &manager.get_systems_resources().await?).await?; + } + })) + } + + async fn usage(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.get_usage().await.map_err(|e| AgentError::Internal(e.to_string())) + } +} + +register_agent!("default", DefaultAgent); + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::{Message, MessageContent}; + use crate::providers::configs::ModelConfig; + use crate::providers::mock::MockProvider; + use async_trait::async_trait; + use chrono::Utc; + use futures::TryStreamExt; + use mcp_core::resource::Resource; + use mcp_core::{Annotations, Content, Tool, ToolCall}; + use serde_json::json; + use std::collections::HashMap; + + // Mock system for testing + struct MockSystem { + name: String, + tools: Vec, + resources: Vec, + resource_content: HashMap, + } + + impl MockSystem { + fn new(name: &str) -> Self { + Self { + name: name.to_string(), + tools: vec![Tool::new( + "echo", + "Echoes back the input", + json!({"type": "object", "properties": {"message": {"type": "string"}}, "required": ["message"]}), + )], + resources: Vec::new(), + resource_content: HashMap::new(), + } + } + + fn add_resource(&mut self, name: &str, content: &str, priority: f32) { + let uri = format!("file://{}", name); + let resource = Resource { + name: name.to_string(), + uri: uri.clone(), + annotations: Some(Annotations::for_resource(priority, Utc::now())), + description: Some("A mock resource".to_string()), + mime_type: "text/plain".to_string(), + }; + self.resources.push(resource); + self.resource_content.insert(uri, content.to_string()); + } + } + + #[async_trait] + impl System for MockSystem { + fn name(&self) -> &str { + &self.name 
+ } + + fn description(&self) -> &str { + "A mock system for testing" + } + + fn instructions(&self) -> &str { + "Mock system instructions" + } + + fn tools(&self) -> &[Tool] { + &self.tools + } + + async fn status(&self) -> anyhow::Result> { + Ok(self.resources.clone()) + } + + async fn call(&self, tool_call: ToolCall) -> AgentResult> { + match tool_call.name.as_str() { + "echo" => Ok(vec![Content::text( + tool_call.arguments["message"].as_str().unwrap_or(""), + )]), + _ => Err(AgentError::ToolNotFound(tool_call.name)), + } + } + + async fn read_resource(&self, uri: &str) -> AgentResult { + self.resource_content.get(uri).cloned().ok_or_else(|| { + AgentError::InvalidParameters(format!("Resource {} could not be found", uri)) + }) + } + } + + #[tokio::test(flavor = "current_thread")] + async fn test_simple_response() -> anyhow::Result<()> { + let response = Message::assistant().with_text("Hello!"); + let provider = MockProvider::new(vec![response.clone()]); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Add a system to test system management + agent.add_system(Box::new(MockSystem::new("test"))).await?; + + let initial_message = Message::user().with_text("Hi"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? 
{ + messages.push(msg); + } + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0], response); + Ok(()) + } + + #[tokio::test(flavor = "current_thread")] + async fn test_system_management() -> anyhow::Result<()> { + let provider = MockProvider::new(vec![]); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Add a system + agent.add_system(Box::new(MockSystem::new("test1"))).await?; + agent.add_system(Box::new(MockSystem::new("test2"))).await?; + + // List systems + let systems = agent.list_systems().await?; + assert_eq!(systems.len(), 2); + assert!(systems.iter().any(|(name, _)| name == "test1")); + assert!(systems.iter().any(|(name, _)| name == "test2")); + + // Remove a system + agent.remove_system("test1").await?; + let systems = agent.list_systems().await?; + assert_eq!(systems.len(), 1); + assert_eq!(systems[0].0, "test2"); + + Ok(()) + } + + #[tokio::test] + async fn test_tool_call() -> anyhow::Result<()> { + let mut agent = DefaultAgent::new(Box::new(MockProvider::new(vec![ + Message::assistant().with_tool_request( + "1", + Ok(ToolCall::new("test_echo", json!({"message": "test"}))), + ), + Message::assistant().with_text("Done!"), + ]))); + + agent.add_system(Box::new(MockSystem::new("test"))).await?; + + let initial_message = Message::user().with_text("Echo test"); + let initial_messages = vec![initial_message]; + + let mut stream = agent.reply(&initial_messages).await?; + let mut messages = Vec::new(); + while let Some(msg) = stream.try_next().await? 
{ + messages.push(msg); + } + + // Should have three messages: tool request, response, and model text + assert_eq!(messages.len(), 3); + assert!(messages[0] + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert_eq!(messages[2].content[0], MessageContent::text("Done!")); + Ok(()) + } + + #[tokio::test] + async fn test_prepare_inference_trims_resources() -> anyhow::Result<()> { + let provider = MockProvider::with_config( + vec![], + ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), + ); + let mut agent = DefaultAgent::new(Box::new(provider)); + + // Create a mock system with resources + let mut system = MockSystem::new("test"); + let hello_1_tokens = "hello ".repeat(1); // 1 tokens + let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens + system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); + system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); + + agent.add_system(Box::new(system)).await?; + + // Set up test parameters + let manager = agent.mcp_manager.lock().await; + + let system_prompt = "This is a system prompt"; + let messages = vec![Message::user().with_text("Hi there")]; + let pending = vec![]; + let tools = vec![]; + let target_limit = manager.provider().get_model_config().context_limit(); + + assert_eq!(target_limit, 20, "Context limit should be 20"); + // Test prepare_inference + let result = agent + .prepare_inference(&system_prompt, &tools, &messages, &pending, target_limit, &manager.provider().get_model_config().model_name, &manager.get_systems_resources().await?) 
+ .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + + // Verify that "hello" is within the response, should be just under 20 tokens with "hello" + assert!(status_content.contains("hello")); + assert!(!status_content.contains("goodbye")); + + Ok(()) + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/factory.rs b/crates/goose/src/agents/factory.rs new file mode 100644 index 000000000..a91b095ab --- /dev/null +++ b/crates/goose/src/agents/factory.rs @@ -0,0 +1,191 @@ +use std::collections::HashMap; +use std::sync::{OnceLock, RwLock}; + +use super::Agent; +use crate::errors::AgentError; +use crate::providers::base::Provider; + +type AgentConstructor = Box) -> Box + Send + Sync>; + +// Use std::sync::RwLock for interior mutability +static AGENT_REGISTRY: OnceLock>> = OnceLock::new(); + +/// Initialize the registry if it hasn't been initialized +fn registry() -> &'static RwLock> { + AGENT_REGISTRY.get_or_init(|| RwLock::new(HashMap::new())) +} + +/// Register a new agent version +pub fn register_agent( + version: &'static str, + constructor: impl Fn(Box) -> Box + Send + Sync + 'static, +) { + let registry = registry(); + if let Ok(mut map) = registry.write() { + map.insert(version, Box::new(constructor)); + } +} + +pub struct AgentFactory; + +impl AgentFactory { + /// Create a new agent instance of the specified version + pub fn create( + version: &str, + provider: Box, + ) -> Result, AgentError> { + let registry = registry(); + if let Ok(map) = registry.read() { + if let Some(constructor) = map.get(version) { + Ok(constructor(provider)) + } else { + Err(AgentError::VersionNotFound(version.to_string())) + } + } else { + Err(AgentError::Internal( + "Failed to access agent registry".to_string(), + )) + } + } + + /// Get a list 
of all available agent versions, sorted by name for stable output + pub fn available_versions() -> Vec<&'static str> { + let lock = registry().read(); + let mut names: Vec<_> = lock.map(|m| m.keys().copied().collect()).unwrap_or_default(); + names.sort_unstable(); + names + } + + /// Get the default version name + pub fn default_version() -> &'static str { + "default" + } +} + +/// Macro to help with agent registration (expands via the public `agents::register_agent` re-export) +#[macro_export] +macro_rules! register_agent { + ($version:expr, $agent_type:ty) => { + paste::paste! { + #[ctor::ctor] + #[allow(non_snake_case)] + fn [<__register_agent_ $version>]() { + $crate::agents::register_agent($version, |provider| { + Box::new(<$agent_type>::new(provider)) + }); + } + } + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::Message; + use crate::providers::mock::MockProvider; + use crate::providers::base::ProviderUsage; + use crate::errors::AgentResult; + use crate::systems::System; + use async_trait::async_trait; + use futures::stream::BoxStream; + use serde_json::Value; + use tokio::sync::Mutex; + + // Test agent implementation + struct TestAgent { + mcp_manager: Mutex, + } + + impl TestAgent { + fn new(provider: Box) -> Self { + Self { + mcp_manager: Mutex::new(super::super::MCPManager::new(provider)), + } + } + } + + #[async_trait] + impl Agent for TestAgent { + async fn add_system(&mut self, system: Box) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.add_system(system); + Ok(()) + } + + async fn remove_system(&mut self, name: &str) -> AgentResult<()> { + let mut manager = self.mcp_manager.lock().await; + manager.remove_system(name) + } + + async fn list_systems(&self) -> AgentResult> { + let manager = self.mcp_manager.lock().await; + manager.list_systems().await + } + + async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult { + Ok(Value::Null) + } + + async fn reply(&self, _messages: &[Message]) -> anyhow::Result>> { + Ok(Box::pin(futures::stream::empty())) + } + + async fn usage(&self) -> AgentResult> { + Ok(vec![]) + } + } + +
#[test] + fn test_register_and_create_agent() { + register_agent!("test_create", TestAgent); + + // Create a mock provider + let provider = Box::new(MockProvider::new(vec![])); + + // Create an agent instance + let result = AgentFactory::create("test_create", provider); + assert!(result.is_ok()); + } + + #[test] + fn test_version_not_found() { + // Try to create an agent with a non-existent version + let provider = Box::new(MockProvider::new(vec![])); + let result = AgentFactory::create("nonexistent", provider); + + assert!(matches!(result, Err(AgentError::VersionNotFound(_)))); + if let Err(AgentError::VersionNotFound(version)) = result { + assert_eq!(version, "nonexistent"); + } + } + + #[test] + fn test_available_versions() { + register_agent!("test_available_1", TestAgent); + register_agent!("test_available_2", TestAgent); + + // Get available versions + let versions = AgentFactory::available_versions(); + + assert!(versions.contains(&"test_available_1")); + assert!(versions.contains(&"test_available_2")); + } + + #[test] + fn test_default_version() { + assert_eq!(AgentFactory::default_version(), "default"); + } + + #[test] + fn test_multiple_registrations() { + register_agent!("test_duplicate", TestAgent); + register_agent!("test_duplicate_other", TestAgent); + + // Create an agent instance + let provider = Box::new(MockProvider::new(vec![])); + let result = AgentFactory::create("test_duplicate", provider); + + // Registering another version must not break the earlier one + assert!(result.is_ok()); + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/mcp_manager.rs b/crates/goose/src/agents/mcp_manager.rs new file mode 100644 index 000000000..c2dab215f --- /dev/null +++ b/crates/goose/src/agents/mcp_manager.rs @@ -0,0 +1,196 @@ +use std::collections::HashMap; +use tokio::sync::Mutex; +use rust_decimal_macros::dec; + +use crate::errors::{AgentError, AgentResult}; +use crate::prompt_template::load_prompt_file; +use crate::systems::System; +use
crate::providers::base::{Provider, ProviderUsage}; +use mcp_core::{Content, Resource, Tool, ToolCall}; +use serde::Serialize; + +#[derive(Clone, Debug, Serialize)] +struct SystemInfo { + name: String, + description: String, + instructions: String, +} + +impl SystemInfo { + fn new(name: &str, description: &str, instructions: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + instructions: instructions.to_string(), + } + } +} + +/// Manages MCP systems and their interactions +pub struct MCPManager { + systems: Vec>, + provider: Box, + provider_usage: Mutex>, +} + +impl MCPManager { + pub fn new(provider: Box) -> Self { + Self { + systems: Vec::new(), + provider, + provider_usage: Mutex::new(Vec::new()), + } + } + + /// Get a reference to the provider + pub fn provider(&self) -> &Box { + &self.provider + } + + /// Record provider usage + pub async fn record_usage(&self, usage: ProviderUsage) { + self.provider_usage.lock().await.push(usage); + } + + /// Get aggregated usage statistics + pub async fn get_usage(&self) -> anyhow::Result> { + let provider_usage = self.provider_usage.lock().await.clone(); + let mut usage_map: HashMap = HashMap::new(); + + provider_usage.iter().for_each(|usage| { + usage_map + .entry(usage.model.clone()) + .and_modify(|e| { + e.usage.input_tokens = Some( + e.usage.input_tokens.unwrap_or(0) + usage.usage.input_tokens.unwrap_or(0), + ); + e.usage.output_tokens = Some( + e.usage.output_tokens.unwrap_or(0) + usage.usage.output_tokens.unwrap_or(0), + ); + e.usage.total_tokens = Some( + e.usage.total_tokens.unwrap_or(0) + usage.usage.total_tokens.unwrap_or(0), + ); + if e.cost.is_none() || usage.cost.is_none() { + e.cost = None; // Pricing is not available for all models + } else { + e.cost = Some(e.cost.unwrap_or(dec!(0)) + usage.cost.unwrap_or(dec!(0))); + } + }) + .or_insert_with(|| usage.clone()); + }); + Ok(usage_map.into_values().collect()) + } + + /// Add a system to the manager + pub fn 
add_system(&mut self, system: Box) { + self.systems.push(system); + } + + /// Remove a system by name + pub fn remove_system(&mut self, name: &str) -> AgentResult<()> { + if let Some(pos) = self.systems.iter().position(|sys| sys.name() == name) { + self.systems.remove(pos); + Ok(()) + } else { + Err(AgentError::SystemNotFound(name.to_string())) + } + } + + /// List all systems and their status + pub async fn list_systems(&self) -> AgentResult> { + let mut statuses = Vec::new(); + for system in &self.systems { + let status = system + .status() + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + statuses.push((system.name().to_string(), format!("{:?}", status))); + } + Ok(statuses) + } + + /// Get all tools from all systems with proper system prefixing + pub fn get_prefixed_tools(&self) -> Vec { + let mut tools = Vec::new(); + for system in &self.systems { + for tool in system.tools() { + tools.push(Tool::new( + format!("{}__{}", system.name(), tool.name), + &tool.description, + tool.input_schema.clone(), + )); + } + } + tools + } + + /// Get system resources and their contents + pub async fn get_systems_resources( + &self, + ) -> AgentResult>> { + let mut system_resource_content = HashMap::new(); + for system in &self.systems { + let system_status = system + .status() + .await + .map_err(|e| AgentError::Internal(e.to_string()))?; + + let mut resource_content = HashMap::new(); + for resource in system_status { + if let Ok(content) = system.read_resource(&resource.uri).await { + resource_content.insert(resource.uri.to_string(), (resource, content)); + } + } + system_resource_content.insert(system.name().to_string(), resource_content); + } + Ok(system_resource_content) + } + + /// Get the system prompt + pub fn get_system_prompt(&self) -> AgentResult { + let mut context = HashMap::new(); + let systems_info: Vec = self + .systems + .iter() + .map(|system| { + SystemInfo::new(system.name(), system.description(), system.instructions()) + }) + .collect(); + + 
context.insert("systems", systems_info); + load_prompt_file("system.md", &context).map_err(|e| AgentError::Internal(e.to_string())) + } + + /// Find the system owning a prefixed tool name (`system__tool`); only the first `__` splits + pub fn get_system_for_tool(&self, prefixed_name: &str) -> Option<&dyn System> { + let parts: Vec<&str> = prefixed_name.splitn(2, "__").collect(); + if parts.len() != 2 { + return None; + } + let system_name = parts[0]; + self.systems + .iter() + .find(|sys| sys.name() == system_name) + .map(|v| &**v) + } + + /// Dispatch a single tool call to its system; the tool name is everything after the first `__` + pub async fn dispatch_tool_call( + &self, + tool_call: AgentResult, + ) -> AgentResult> { + let call = tool_call?; + let system = self + .get_system_for_tool(&call.name) + .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; + + let tool_name = call + .name + .split_once("__") + .map(|(_, tool)| tool) + .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; + let system_tool_call = ToolCall::new(tool_name, call.arguments); + + system.call(system_tool_call).await + } +} \ No newline at end of file diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs new file mode 100644 index 000000000..5905bd186 --- /dev/null +++ b/crates/goose/src/agents/mod.rs @@ -0,0 +1,9 @@ +mod agent; +mod default; +mod factory; +mod mcp_manager; + +pub use agent::Agent; +pub use default::DefaultAgent; +pub use factory::{register_agent, AgentFactory}; +pub use mcp_manager::MCPManager; \ No newline at end of file diff --git a/crates/goose/src/errors.rs b/crates/goose/src/errors.rs index 2d2497bdf..045ef8fa7 100644 --- a/crates/goose/src/errors.rs +++ b/crates/goose/src/errors.rs @@ -18,6 +18,12 @@ pub enum AgentError { #[error("Invalid tool name: {0}")] InvalidToolName(String), + + #[error("System not found: {0}")] + SystemNotFound(String), + + #[error("Agent version not found: {0}")] + VersionNotFound(String), } -pub type AgentResult = Result; +pub type AgentResult = Result; \ No newline at end of
file diff --git a/crates/goose/src/lib.rs b/crates/goose/src/lib.rs index 394e6450c..cede26b24 100644 --- a/crates/goose/src/lib.rs +++ b/crates/goose/src/lib.rs @@ -1,4 +1,4 @@ -pub mod agent; +pub mod agents; pub mod developer; pub mod errors; pub mod key_manager; diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 6d564eeb8..badc74098 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -3,6 +3,7 @@ pub mod base; pub mod configs; pub mod databricks; pub mod factory; +pub mod mock; pub mod model_pricing; pub mod oauth; pub mod ollama; @@ -13,7 +14,5 @@ pub mod utils; pub mod google; pub mod groq; -#[cfg(test)] -pub mod mock; #[cfg(test)] pub mod mock_server; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index fa84a63af..54aed6ad2 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -21,7 +21,7 @@ impl MockProvider { pub fn new(responses: Vec) -> Self { Self { responses: Arc::new(Mutex::new(responses)), - model_config: ModelConfig::new("mock-model".to_string()), + model_config: ModelConfig::new("mock".to_string()), } } @@ -62,7 +62,7 @@ impl Provider for MockProvider { } } - fn get_usage(&self, data: &Value) -> Result { + fn get_usage(&self, _data: &Value) -> Result { Ok(Usage::new(None, None, None)) } }