Skip to content

Commit

Permalink
Basis for named agents (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali authored and baxen committed Jan 6, 2025
1 parent e798dc9 commit 51d9b56
Show file tree
Hide file tree
Showing 25 changed files with 1,228 additions and 894 deletions.
28 changes: 0 additions & 28 deletions crates/goose-cli/src/agents/agent.rs

This file was deleted.

54 changes: 45 additions & 9 deletions crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,59 @@
use std::vec;

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{message::Message, providers::base::ProviderUsage, systems::System};
use goose::providers::mock::MockProvider;
use goose::{
agents::Agent,
errors::AgentResult,
message::Message,
providers::base::{Provider, ProviderUsage},
systems::System,
};
use serde_json::Value;
use tokio::sync::Mutex;

use crate::agents::agent::Agent;
pub struct MockAgent {
systems: Vec<Box<dyn System>>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
}

pub struct MockAgent;
impl MockAgent {
pub fn new() -> Self {
Self {
systems: Vec::new(),
provider: Box::new(MockProvider::new(Vec::new())),
provider_usage: Mutex::new(Vec::new()),
}
}
}

#[async_trait]
impl Agent for MockAgent {
fn add_system(&mut self, _system: Box<dyn System>) {}
async fn add_system(&mut self, system: Box<dyn System>) -> AgentResult<()> {
self.systems.push(system);
Ok(())
}

async fn remove_system(&mut self, name: &str) -> AgentResult<()> {
self.systems.retain(|s| s.name() != name);
Ok(())
}

async fn list_systems(&self) -> AgentResult<Vec<(String, String)>> {
Ok(self.systems.iter()
.map(|s| (s.name().to_string(), s.description().to_string()))
.collect())
}

async fn passthrough(&self, _system: &str, _request: Value) -> AgentResult<Value> {
Ok(Value::Null)
}

async fn reply(&self, _messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
async fn reply(&self, _messages: &[Message]) -> anyhow::Result<BoxStream<'_, anyhow::Result<Message>>> {
Ok(Box::pin(futures::stream::empty()))
}

async fn usage(&self) -> Result<Vec<ProviderUsage>> {
async fn usage(&self) -> AgentResult<Vec<ProviderUsage>> {
Ok(vec![ProviderUsage::new(
"mock".to_string(),
Default::default(),
Expand Down
2 changes: 0 additions & 2 deletions crates/goose-cli/src/agents/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
pub mod agent;

#[cfg(test)]
pub mod mock_agent;
28 changes: 28 additions & 0 deletions crates/goose-cli/src/commands/agent_version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use anyhow::Result;
use clap::Args;
use goose::agents::AgentFactory;
use std::fmt::Write;

#[derive(Args)]
pub struct AgentCommand {}

impl AgentCommand {
pub fn run(&self) -> Result<()> {
let mut output = String::new();
writeln!(output, "Available agent versions:")?;

let versions = AgentFactory::available_versions();
let default_version = AgentFactory::default_version();

for version in versions {
if version == default_version {
writeln!(output, "* {} (default)", version)?;
} else {
writeln!(output, " {}", version)?;
}
}

print!("{}", output);
Ok(())
}
}
2 changes: 1 addition & 1 deletion crates/goose-cli/src/commands/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod agent_version;
pub mod configure;
pub mod session;
pub mod version;
pub mod expected_config;
11 changes: 6 additions & 5 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use console::style;
use goose::agent::Agent;
use goose::agents::AgentFactory;
use goose::providers::factory;
use rand::{distributions::Alphanumeric, Rng};
use std::path::{Path, PathBuf};
Expand All @@ -13,6 +13,7 @@ use crate::session::{ensure_session_dir, get_most_recent_session, Session};
pub fn build_session<'a>(
session: Option<String>,
profile: Option<String>,
agent_version: Option<String>,
resume: bool,
) -> Box<Session<'a>> {
let session_dir = ensure_session_dir().expect("Failed to create session directory");
Expand Down Expand Up @@ -45,7 +46,7 @@ pub fn build_session<'a>(

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
let agent = Box::new(Agent::new(provider));
let agent = AgentFactory::create(agent_version.as_deref().unwrap_or("default"), provider).unwrap();
let prompt = match std::env::var("GOOSE_INPUT") {
Ok(val) => match val.as_str() {
"rustyline" => Box::new(RustylinePrompt::new()) as Box<dyn Prompt>,
Expand Down Expand Up @@ -173,7 +174,7 @@ mod tests {
#[should_panic(expected = "Cannot resume session: file")]
fn test_resume_nonexistent_session_panics() {
run_with_tmp_dir(|| {
build_session(Some("nonexistent-session".to_string()), None, true);
build_session(Some("nonexistent-session".to_string()), None, None, true);
})
}

Expand All @@ -190,7 +191,7 @@ mod tests {
fs::write(&file2_path, "{}")?;

// Test resuming without a session name
let session = build_session(None, None, true);
let session = build_session(None, None, None, true);
assert_eq!(session.session_file().as_path(), file2_path.as_path());

Ok(())
Expand All @@ -201,7 +202,7 @@ mod tests {
#[should_panic(expected = "No session files found")]
fn test_resume_most_recent_session_no_files() {
run_with_tmp_dir(|| {
build_session(None, None, true);
build_session(None, None, None, true);
});
}
}
82 changes: 68 additions & 14 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
mod commands {
pub mod configure;
pub mod session;
pub mod version;
}
pub mod agents;
use anyhow::Result;
use clap::{Parser, Subcommand};
use goose::agents::AgentFactory;

mod agents;
mod commands;
mod log_usage;
mod profile;
mod prompt;
pub mod session;

mod session;
mod systems;

use anyhow::Result;
use clap::{Parser, Subcommand};
use commands::agent_version::AgentCommand;
use commands::configure::handle_configure;
use commands::session::build_session;
use commands::version::print_version;
use profile::has_no_profiles;
use std::io::{self, Read};

mod log_usage;

#[cfg(test)]
mod test_helpers;

Expand Down Expand Up @@ -98,6 +95,15 @@ enum Command {
)]
profile: Option<String>,

/// Agent version to use (e.g., 'default', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'default', 'v1'), defaults to 'default'",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous session
#[arg(
short,
Expand Down Expand Up @@ -151,6 +157,15 @@ enum Command {
)]
name: Option<String>,

/// Agent version to use (e.g., 'default', 'v1')
#[arg(
short,
long,
help = "Agent version to use (e.g., 'default', 'v1')",
long_help = "Specify which agent version to use for this session."
)]
agent: Option<String>,

/// Resume a previous run
#[arg(
short,
Expand All @@ -161,6 +176,9 @@ enum Command {
)]
resume: bool,
},

/// List available agent versions
Agents(AgentCommand),
}

#[derive(Subcommand)]
Expand Down Expand Up @@ -224,9 +242,25 @@ async fn main() -> Result<()> {
Some(Command::Session {
name,
profile,
agent,
resume,
}) => {
let mut session = build_session(name, profile, resume);
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let mut session = build_session(name, profile, agent, resume);
let _ = session.start().await;
return Ok(());
}
Expand All @@ -235,8 +269,24 @@ async fn main() -> Result<()> {
input_text,
profile,
name,
agent,
resume,
}) => {
if let Some(agent_version) = agent.clone() {
if !AgentFactory::available_versions().contains(&agent_version.as_str()) {
eprintln!("Error: Invalid agent version '{}'", agent_version);
eprintln!("Available versions:");
for version in AgentFactory::available_versions() {
if version == AgentFactory::default_version() {
eprintln!("* {} (default)", version);
} else {
eprintln!(" {}", version);
}
}
std::process::exit(1);
}
}

let contents = if let Some(file_name) = instructions {
let file_path = std::path::Path::new(&file_name);
std::fs::read_to_string(file_path).expect("Failed to read the instruction file")
Expand All @@ -249,10 +299,14 @@ async fn main() -> Result<()> {
.expect("Failed to read from stdin");
stdin
};
let mut session = build_session(name, profile, resume);
let mut session = build_session(name, profile, agent, resume);
let _ = session.headless_start(contents.clone()).await;
return Ok(());
}
Some(Command::Agents(cmd)) => {
cmd.run()?;
return Ok(());
}
None => {
println!("No command provided - Run 'goose help' to see available commands.");
if has_no_profiles().unwrap_or(false) {
Expand Down
Loading

0 comments on commit 51d9b56

Please sign in to comment.