From a46fed8af15b25ba49b607e6be08ed9e889c0413 Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 27 Dec 2024 13:31:12 -0800 Subject: [PATCH] Add agent versions to session command --- crates/goose-cli/src/agents/agent.rs | 8 +- .../goose-cli/src/commands/agent_version.rs | 39 ++++----- crates/goose-cli/src/commands/mod.rs | 2 +- crates/goose-cli/src/commands/session.rs | 11 +-- crates/goose-cli/src/main.rs | 84 +++++++++++++------ crates/goose-cli/src/session.rs | 1 + 6 files changed, 81 insertions(+), 64 deletions(-) diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index 5c3b51e45..85730ab10 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -12,6 +12,7 @@ pub struct GooseAgent { provider_usage: Mutex>, } +#[allow(dead_code)] impl GooseAgent { pub fn new(provider: Box) -> Self { Self { @@ -39,11 +40,4 @@ impl Agent for GooseAgent { fn get_provider_usage(&self) -> &Mutex> { &self.provider_usage } - // async fn reply(&self, messages: &[Message]) -> Result>> { - // self.reply(messages).await - // } - - // async fn usage(&self) -> Result> { - // self.usage().await - // } } diff --git a/crates/goose-cli/src/commands/agent_version.rs b/crates/goose-cli/src/commands/agent_version.rs index 7b3c35fb0..2f1628367 100644 --- a/crates/goose-cli/src/commands/agent_version.rs +++ b/crates/goose-cli/src/commands/agent_version.rs @@ -4,34 +4,25 @@ use goose::agents::AgentFactory; use std::fmt::Write; #[derive(Args)] -pub struct AgentCommand { - /// List available agent versions - #[arg(short, long)] - list: bool, -} +pub struct AgentCommand {} impl AgentCommand { pub fn run(&self) -> Result<()> { - if self.list { - let mut output = String::new(); - writeln!(output, "Available agent versions:")?; - - let versions = AgentFactory::available_versions(); - let default_version = AgentFactory::default_version(); - - for version in versions { - if version == default_version { - writeln!(output, "* {} (default)", version)?; - } else { - writeln!(output, " {}", version)?; - } + let mut output = String::new(); + writeln!(output, "Available agent versions:")?; + + let versions = AgentFactory::available_versions(); + let default_version = AgentFactory::default_version(); + + for version in versions { + if version == default_version { + writeln!(output, "* {} (default)", version)?; + } else { + writeln!(output, " {}", version)?; } - - print!("{}", output); - } else { - // When no flags are provided, show the default version - println!("Default version: {}", AgentFactory::default_version()); } + + print!("{}", output); Ok(()) } -} \ No newline at end of file +} diff --git a/crates/goose-cli/src/commands/mod.rs b/crates/goose-cli/src/commands/mod.rs index d4e08c617..b84916a20 100644 --- a/crates/goose-cli/src/commands/mod.rs +++ b/crates/goose-cli/src/commands/mod.rs @@ -1,4 +1,4 @@ +pub mod agent_version; pub mod configure; pub mod session; pub mod version; -pub mod agent_version; \ No newline at end of file diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index 2d4f4eee9..1014c1d75 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -1,10 +1,10 @@ use console::style; +use goose::agents::AgentFactory; use goose::providers::factory; use rand::{distributions::Alphanumeric, Rng}; use std::path::{Path, PathBuf}; use std::process; -use crate::agents::agent::GooseAgent; use crate::profile::{get_provider_config, load_profiles, Profile}; use crate::prompt::rustyline::RustylinePrompt; use crate::prompt::Prompt; @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session}; pub fn build_session<'a>( session: Option, profile: Option, + agent_version: Option, resume: bool, ) -> Box> { let session_dir = ensure_session_dir().expect("Failed to create session directory"); @@ -45,7 +46,7 @@ pub fn build_session<'a>( // TODO: Odd to be prepping the provider rather than having that done in the agent? let provider = factory::get_provider(provider_config).unwrap(); - let agent = Box::new(GooseAgent::new(provider)); + let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("base"), provider).unwrap(); let prompt = match std::env::var("GOOSE_INPUT") { Ok(val) => match val.as_str() { "rustyline" => Box::new(RustylinePrompt::new()) as Box, @@ -173,7 +174,7 @@ mod tests { #[should_panic(expected = "Cannot resume session: file")] fn test_resume_nonexistent_session_panics() { run_with_tmp_dir(|| { - build_session(Some("nonexistent-session".to_string()), None, true); + build_session(Some("nonexistent-session".to_string()), None, None, true); }) } @@ -190,7 +191,7 @@ mod tests { fs::write(&file2_path, "{}")?; // Test resuming without a session name - let session = build_session(None, None, true); + let session = build_session(None, None, None, true); assert_eq!(session.session_file().as_path(), file2_path.as_path()); Ok(()) @@ -201,7 +202,7 @@ mod tests { #[should_panic(expected = "No session files found")] fn test_resume_most_recent_session_no_files() { run_with_tmp_dir(|| { - build_session(None, None, true); + build_session(None, None, None, true); }); } } diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index acda7c4b3..31f7e8b17 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -4,16 +4,16 @@ use goose::agents::AgentFactory; mod agents; mod commands; +mod log_usage; mod profile; mod prompt; mod session; mod systems; -mod log_usage; +use commands::agent_version::AgentCommand; use commands::configure::handle_configure; use commands::session::build_session; use commands::version::print_version; -use commands::agent_version::AgentCommand; use profile::has_no_profiles; use std::io::{self, Read}; @@ -28,10 +28,6 @@ struct Cli { #[arg(short = 'v', long = "version")] version: bool, - /// Agent version to use (e.g., 'base', 'v1') - #[arg(short = 'a', long = "agent", default_value_t = String::from("base"))] - agent: String, - #[command(subcommand)] command: Option, } @@ -99,6 +95,15 @@ enum Command { )] profile: Option, + /// Agent version to use (e.g., 'base', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'base', 'v1'), defaults to 'base'", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous session #[arg( short, @@ -152,6 +157,15 @@ enum Command { )] name: Option, + /// Agent version to use (e.g., 'base', 'v1') + #[arg( + short, + long, + help = "Agent version to use (e.g., 'base', 'v1')", + long_help = "Specify which agent version to use for this session." + )] + agent: Option, + /// Resume a previous run #[arg( short, @@ -164,7 +178,7 @@ enum Command { }, /// List available agent versions - Agent(AgentCommand), + Agents(AgentCommand), } #[derive(Subcommand)] @@ -206,20 +220,6 @@ async fn main() -> Result<()> { return Ok(()); } - // Validate agent version - if !AgentFactory::available_versions().contains(&cli.agent.as_str()) { - eprintln!("Error: Invalid agent version '{}'", cli.agent); - eprintln!("Available versions:"); - for version in AgentFactory::available_versions() { - if version == AgentFactory::default_version() { - eprintln!("* {} (default)", version); - } else { - eprintln!(" {}", version); - } - } - std::process::exit(1); - } - match cli.command { Some(Command::Configure { profile_name, @@ -242,10 +242,25 @@ async fn main() -> Result<()> { Some(Command::Session { name, profile, + agent, resume, }) => { - let mut session = build_session(name, profile, resume); - session.agent_version = cli.agent; + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + + let mut session = build_session(name, profile, agent, resume); let _ = session.start().await; return Ok(()); } @@ -254,8 +269,24 @@ async fn main() -> Result<()> { input_text, profile, name, + agent, resume, }) => { + if let Some(agent_version) = agent.clone() { + if !AgentFactory::available_versions().contains(&agent_version.as_str()) { + eprintln!("Error: Invalid agent version '{}'", agent_version); + eprintln!("Available versions:"); + for version in AgentFactory::available_versions() { + if version == AgentFactory::default_version() { + eprintln!("* {} (default)", version); + } else { + eprintln!(" {}", version); + } + } + std::process::exit(1); + } + } + let contents = if let Some(file_name) = instructions { let file_path = std::path::Path::new(&file_name); std::fs::read_to_string(file_path).expect("Failed to read the instruction file") @@ -268,12 +299,11 @@ async fn main() -> Result<()> { .expect("Failed to read from stdin"); stdin }; - let mut session = build_session(name, profile, resume); - session.agent_version = cli.agent; + let mut session = build_session(name, profile, agent, resume); let _ = session.headless_start(contents.clone()).await; return Ok(()); } - Some(Command::Agent(cmd)) => { + Some(Command::Agents(cmd)) => { cmd.run()?; return Ok(()); } @@ -285,4 +315,4 @@ async fn main() -> Result<()> { } } Ok(()) -} \ No newline at end of file +} diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 3483ec5e5..8454021a7 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -102,6 +102,7 @@ pub struct Session<'a> { messages: Vec, } +#[allow(dead_code)] impl<'a> Session<'a> { pub fn new( agent: Box,