Skip to content

Commit

Permalink
Add agent versions to session command
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali committed Dec 27, 2024
1 parent 10343f4 commit a46fed8
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 64 deletions.
8 changes: 1 addition & 7 deletions crates/goose-cli/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct GooseAgent {
provider_usage: Mutex<Vec<ProviderUsage>>,
}

#[allow(dead_code)]
impl GooseAgent {
pub fn new(provider: Box<dyn Provider>) -> Self {
Self {
Expand Down Expand Up @@ -39,11 +40,4 @@ impl Agent for GooseAgent {
fn get_provider_usage(&self) -> &Mutex<Vec<ProviderUsage>> {
&self.provider_usage
}
// async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
// self.reply(messages).await
// }

// async fn usage(&self) -> Result<Vec<ProviderUsage>> {
// self.usage().await
// }
}
39 changes: 15 additions & 24 deletions crates/goose-cli/src/commands/agent_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,25 @@ use goose::agents::AgentFactory;
use std::fmt::Write;

#[derive(Args)]
pub struct AgentCommand {
/// List available agent versions
#[arg(short, long)]
list: bool,
}
pub struct AgentCommand {}

impl AgentCommand {
pub fn run(&self) -> Result<()> {
if self.list {
let mut output = String::new();
writeln!(output, "Available agent versions:")?;

let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version();

for version in versions {
if version == default_version {
writeln!(output, "* {} (default)", version)?;
} else {
writeln!(output, " {}", version)?;
}
let mut output = String::new();
writeln!(output, "Available agent versions:")?;

let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version();

for version in versions {
if version == default_version {
writeln!(output, "* {} (default)", version)?;
} else {
writeln!(output, " {}", version)?;
}

print!("{}", output);
} else {
// When no flags are provided, show the default version
println!("Default version: {}", AgentFactory::default_version());
}

print!("{}", output);
Ok(())
}
}
}
2 changes: 1 addition & 1 deletion crates/goose-cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod agent_version;
pub mod configure;
pub mod session;
pub mod version;
pub mod agent_version;
11 changes: 6 additions & 5 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use console::style;
use goose::agents::AgentFactory;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
use std::process;

use crate::agents::agent::GooseAgent;
use crate::profile::{get_provider_config, load_profiles, Profile};
use crate::prompt::rustyline::RustylinePrompt;
use crate::prompt::Prompt;
Expand All @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session};
pub fn build_session<'a>(
session: Option<String>,
profile: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
let session_dir = ensure_session_dir().expect("Failed to create session directory");
Expand Down Expand Up @@ -45,7 +46,7 @@ pub fn build_session<'a>(

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
let agent = Box::new(GooseAgent::new(provider));
let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("base"), provider).unwrap();
let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"rustyline" => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
Expand Down Expand Up @@ -173,7 +174,7 @@ mod tests {
#[should_panic(expected = "Cannot resume session: file")]
fn test_resume_nonexistent_session_panics() {
run_with_tmp_dir(|| {
build_session(Some("nonexistent-session".to_string()), None, true);
build_session(Some("nonexistent-session".to_string()), None, None, true);
})
}

Expand All @@ -190,7 +191,7 @@ mod tests {
fs::write(&file2_path, "{}")?;

// Test resuming without a session name
let session = build_session(None, None, true);
let session = build_session(None, None, None, true);
assert_eq!(session.session_file().as_path(), file2_path.as_path());

Ok(())
Expand All @@ -201,7 +202,7 @@ mod tests {
#[should_panic(expected = "No session files found")]
fn test_resume_most_recent_session_no_files() {
run_with_tmp_dir(|| {
build_session(None, None, true);
build_session(None, None, None, true);
});
}
}
84 changes: 57 additions & 27 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ use goose::agents::AgentFactory;

mod agents;
mod commands;
mod log_usage;
mod profile;
mod prompt;
mod session;
mod systems;
mod log_usage;

use commands::agent_version::AgentCommand;
use commands::configure::handle_configure;
use commands::session::build_session;
use commands::version::print_version;
use commands::agent_version::AgentCommand;
use profile::has_no_profiles;
use std::io::{self, Read};

Expand All @@ -28,10 +28,6 @@ struct Cli {
#[arg(short = 'v', long = "version")]
version: bool,

/// Agent version to use (e.g., 'base', 'v1')
#[arg(short = 'a', long = "agent", default_value_t = String::from("base"))]
agent: String,

#[command(subcommand)]
command: Option<Command>,
}
Expand Down Expand Up @@ -99,6 +95,15 @@ enum Command {
)]
profile: Option<String>,

/// Agent version to use (e.g., 'base', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'base', 'v1'), defaults to 'base'",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous session
#[arg(
short,
Expand Down Expand Up @@ -152,6 +157,15 @@ enum Command {
)]
name: Option<String>,

/// Agent version to use (e.g., 'base', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'base', 'v1')",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous run
#[arg(
short,
Expand All @@ -164,7 +178,7 @@ enum Command {
},

/// List available agent versions
Agent(AgentCommand),
Agents(AgentCommand),
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -206,20 +220,6 @@ async fn main() -> Result<()> {
return Ok(());
}

// Validate agent version
if !AgentFactory::available_versions().contains(&cli.agent.as_str()) {
eprintln!("Error: Invalid agent version '{}'", cli.agent);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}

match cli.command {
Some(Command::Configure {
profile_name,
Expand All @@ -242,10 +242,25 @@ async fn main() -> Result<()> {
Some(Command::Session {
name,
profile,
agent,
resume,
}) => {
let mut session = build_session(name, profile, resume);
session.agent_version = cli.agent;
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let mut session = build_session(name, profile, agent, resume);
let _ = session.start().await;
return Ok(());
}
Expand All @@ -254,8 +269,24 @@ async fn main() -> Result<()> {
input_text,
profile,
name,
agent,
resume,
}) => {
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let contents = if let Some(file_name) = instructions {
let file_path = std::path::Path::new(&file_name);
std::fs::read_to_string(file_path).expect("Failed to read the instruction file")
Expand All @@ -268,12 +299,11 @@ async fn main() -> Result<()> {
.expect("Failed to read from stdin");
stdin
};
let mut session = build_session(name, profile, resume);
session.agent_version = cli.agent;
let mut session = build_session(name, profile, agent, resume);
let _ = session.headless_start(contents.clone()).await;
return Ok(());
}
Some(Command::Agent(cmd)) => {
Some(Command::Agents(cmd)) => {
cmd.run()?;
return Ok(());
}
Expand All @@ -285,4 +315,4 @@ async fn main() -> Result<()> {
}
}
Ok(())
}
}
1 change: 1 addition & 0 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pub struct Session<'a> {
messages: Vec<Message>,
}

#[allow(dead_code)]
impl<'a> Session<'a> {
pub fn new(
agent: Box<dyn Agent>,
Expand Down

0 comments on commit a46fed8

Please sign in to comment.