Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basis for named agents #525

Merged
merged 11 commits into from
Jan 1, 2025
Merged
28 changes: 0 additions & 28 deletions crates/goose-cli/src/agents/agent.rs

This file was deleted.

54 changes: 45 additions & 9 deletions crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,59 @@
use std::vec;

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{message::Message, providers::base::ProviderUsage, systems::System};
use goose::providers::mock::MockProvider;
use goose::{
agents::Agent,
errors::AgentResult,
message::Message,
providers::base::{Provider, ProviderUsage},
systems::System,
};
use serde_json::Value;
use tokio::sync::Mutex;

use crate::agents::agent::Agent;
pub struct MockAgent {
systems: Vec<Box<dyn System>>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
}

pub struct MockAgent;
impl MockAgent {
pub fn new() -> Self {
Self {
systems: Vec::new(),
provider: Box::new(MockProvider::new(Vec::new())),
provider_usage: Mutex::new(Vec::new()),
}
}
}

#[async_trait]
impl Agent for MockAgent {
fn add_system(&mut self, _system: Box<dyn System>) {}
async fn add_system(&mut self, system: Box<dyn System>) -> AgentResult<()> {
self.systems.push(system);
Ok(())
}

async fn remove_system(&mut self, name: &str) -> AgentResult<()> {
self.systems.retain(|s| s.name() != name);
Ok(())
}

async fn list_systems(&self) -> AgentResult<Vec<(String, String)>> {
Ok(self.systems.iter()
.map(|s| (s.name().to_string(), s.description().to_string()))
.collect())
}

async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult<Value> {
Ok(Value::Null)
}

async fn reply(&self, _messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
async fn reply(&self, _messages: &[Message]) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
Ok(Box::pin(futures::stream::empty()))
}

async fn usage(&self) -> Result<Vec<ProviderUsage>> {
async fn usage(&self) -> AgentResult<Vec<ProviderUsage>> {
Ok(vec![ProviderUsage::new(
"mock".to_string(),
Default::default(),
Expand Down
2 changes: 0 additions & 2 deletions crates/goose-cli/src/agents/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
pub mod agent;

#[cfg(test)]
pub mod mock_agent;
28 changes: 28 additions & 0 deletions crates/goose-cli/src/commands/agent_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use anyhow::Result;
use clap::Args;
use goose::agents::AgentFactory;
use std::fmt::Write;

#[derive(Args)]
pub struct AgentCommand {}

impl AgentCommand {
pub fn run(&self) -> Result<()> {
let mut output = String::new();
writeln!(output, "Available agent versions:")?;

let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version();

for version in versions {
if version == default_version {
writeln!(output, "* {} (default)", version)?;
} else {
writeln!(output, " {}", version)?;
}
}

print!("{}", output);
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/goose-cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod agent_version;
pub mod configure;
pub mod session;
pub mod version;
pub mod expected_config;
11 changes: 6 additions & 5 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use console::style;
use goose::agent::Agent;
use goose::agents::AgentFactory;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
Expand All @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session};
pub fn build_session<'a>(
session: Option<String>,
profile: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
let session_dir = ensure_session_dir().expect("Failed to create session directory");
Expand Down Expand Up @@ -45,7 +46,7 @@ pub fn build_session<'a>(

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
let agent = Box::new(Agent::new(provider));
let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("default"), provider).unwrap();
let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"rustyline" => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
Expand Down Expand Up @@ -173,7 +174,7 @@ mod tests {
#[should_panic(expected = "Cannot resume session: file")]
fn test_resume_nonexistent_session_panics() {
run_with_tmp_dir(|| {
build_session(Some("nonexistent-session".to_string()), None, true);
build_session(Some("nonexistent-session".to_string()), None, None, true);
})
}

Expand All @@ -190,7 +191,7 @@ mod tests {
fs::write(&file2_path, "{}")?;

// Test resuming without a session name
let session = build_session(None, None, true);
let session = build_session(None, None, None, true);
assert_eq!(session.session_file().as_path(), file2_path.as_path());

Ok(())
Expand All @@ -201,7 +202,7 @@ mod tests {
#[should_panic(expected = "No session files found")]
fn test_resume_most_recent_session_no_files() {
run_with_tmp_dir(|| {
build_session(None, None, true);
build_session(None, None, None, true);
});
}
}
82 changes: 68 additions & 14 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
mod commands {
pub mod configure;
pub mod session;
pub mod version;
}
pub mod agents;
use anyhow::Result;
use clap::{Parser, Subcommand};
use goose::agents::AgentFactory;

mod agents;
mod commands;
mod log_usage;
mod profile;
mod prompt;
pub mod session;

mod session;
mod systems;

use anyhow::Result;
use clap::{Parser, Subcommand};
use commands::agent_version::AgentCommand;
use commands::configure::handle_configure;
use commands::session::build_session;
use commands::version::print_version;
use profile::has_no_profiles;
use std::io::{self, Read};

mod log_usage;

#[cfg(test)]
mod test_helpers;

Expand Down Expand Up @@ -98,6 +95,15 @@ enum Command {
)]
profile: Option<String>,

/// Agent version to use (e.g., 'default', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'default', 'v1'), defaults to 'default'",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous session
#[arg(
short,
Expand Down Expand Up @@ -151,6 +157,15 @@ enum Command {
)]
name: Option<String>,

/// Agent version to use (e.g., 'default', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'default', 'v1')",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous run
#[arg(
short,
Expand All @@ -161,6 +176,9 @@ enum Command {
)]
resume: bool,
},

/// List available agent versions
Agents(AgentCommand),
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -224,9 +242,25 @@ async fn main() -> Result<()> {
Some(Command::Session {
name,
profile,
agent,
resume,
}) => {
let mut session = build_session(name, profile, resume);
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let mut session = build_session(name, profile, agent, resume);
let _ = session.start().await;
return Ok(());
}
Expand All @@ -235,8 +269,24 @@ async fn main() -> Result<()> {
input_text,
profile,
name,
agent,
resume,
}) => {
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let contents = if let Some(file_name) = instructions {
let file_path = std::path::Path::new(&file_name);
std::fs::read_to_string(file_path).expect("Failed to read the instruction file")
Expand All @@ -249,10 +299,14 @@ async fn main() -> Result<()> {
.expect("Failed to read from stdin");
stdin
};
let mut session = build_session(name, profile, resume);
let mut session = build_session(name, profile, agent, resume);
let _ = session.headless_start(contents.clone()).await;
return Ok(());
}
Some(Command::Agents(cmd)) => {
cmd.run()?;
return Ok(());
}
None => {
println!("No command provided - Run 'goose help' to see available commands.");
if has_no_profiles().unwrap_or(false) {
Expand Down
Loading