diff --git a/crates/goose-cli/src/agents/agent.rs b/crates/goose-cli/src/agents/agent.rs index 850f7685b..0ba751fcb 100644 --- a/crates/goose-cli/src/agents/agent.rs +++ b/crates/goose-cli/src/agents/agent.rs @@ -1,12 +1,20 @@ use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{agent::Agent as GooseAgent, models::message::Message, systems::System}; +use goose::{ + agent::{Agent as GooseAgent, ApprovalMonitor}, + models::message::Message, + systems::System, +}; #[async_trait] pub trait Agent { fn add_system(&mut self, system: Box); - async fn reply(&self, messages: &[Message]) -> Result>>; + async fn reply( + &self, + messages: &[Message], + approval_monitor: ApprovalMonitor, + ) -> Result>>; } #[async_trait] @@ -15,7 +23,11 @@ impl Agent for GooseAgent { self.add_system(system); } - async fn reply(&self, messages: &[Message]) -> Result>> { - self.reply(messages).await + async fn reply( + &self, + messages: &[Message], + approval_monitor: ApprovalMonitor, + ) -> Result>> { + self.reply(messages, approval_monitor).await } } diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index e19b0eb8a..f734d6a9c 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -1,7 +1,7 @@ use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; -use goose::{models::message::Message, systems::System}; +use goose::{agent::ApprovalMonitor, models::message::Message, systems::System}; use crate::agents::agent::Agent; @@ -11,7 +11,11 @@ pub struct MockAgent; impl Agent for MockAgent { fn add_system(&mut self, _system: Box) {} - async fn reply(&self, _messages: &[Message]) -> Result>> { + async fn reply( + &self, + _messages: &[Message], + approval_monitor: ApprovalMonitor, + ) -> Result>> { Ok(Box::pin(futures::stream::empty())) } } diff --git a/crates/goose-cli/src/prompt.rs b/crates/goose-cli/src/prompt.rs index 386f079d3..64e550388 100644 --- a/crates/goose-cli/src/prompt.rs +++ b/crates/goose-cli/src/prompt.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use goose::models::message::Message; +use goose::models::message::{ApprovalRequest, Message}; pub mod cliclack; pub mod renderer; @@ -9,6 +9,9 @@ pub mod thinking; pub trait Prompt { fn render(&mut self, message: Box); fn get_input(&mut self) -> Result; + fn handle_approval_request(&mut self, approval_request: ApprovalRequest) -> Result { + Err(anyhow::anyhow!("Not implemented")) + } fn show_busy(&mut self); fn hide_busy(&self); fn close(&self); diff --git a/crates/goose-cli/src/prompt/renderer.rs b/crates/goose-cli/src/prompt/renderer.rs index 2425ae162..67fa9590e 100644 --- a/crates/goose-cli/src/prompt/renderer.rs +++ b/crates/goose-cli/src/prompt/renderer.rs @@ -3,7 +3,7 @@ use std::io::{self, Write}; use bat::WrappingMode; use console::style; -use goose::models::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use goose::models::message::{ApprovalRequest, Message, MessageContent, ToolRequest, ToolResponse}; use goose::models::role::Role; use goose::models::{content::Content, tool::ToolCall}; use serde_json::Value; @@ -57,6 +57,7 @@ impl ToolRenderer for DefaultRenderer { // Format and print the parameters print_params(&call.arguments, 0); + render_approval_required(tool_request.approval_request.clone()); print_newline(); } Err(e) => print_markdown(&e.to_string(), theme), @@ -87,6 +88,7 @@ impl ToolRenderer for BashDeveloperSystemRenderer { } _ => print_params(&call.arguments, 0), } + render_approval_required(tool_request.approval_request.clone()); print_newline(); } Err(e) => print_markdown(&e.to_string(), theme), @@ -186,6 +188,24 @@ pub fn default_print_request_header(call: &ToolCall) { println!("{}", tool_header); } +pub fn render_approval_required(approval_request: Option) { + if let Some(approval_request) = approval_request { + println!( + "{}", + style("────────────────────────────────────────────────────────").red() + ); + println!( + "{} Approval request id: {}", + style("|").red(), + approval_request.id + ); + println!( + "{}", + style("────────────────────────────────────────────────────────").red() + ); + } +} + pub fn print_markdown(content: &str, theme: &str) { bat::PrettyPrinter::new() .input(bat::Input::from_bytes(content.as_bytes())) diff --git a/crates/goose-cli/src/prompt/rustyline.rs b/crates/goose-cli/src/prompt/rustyline.rs index 1ed6e8c1a..707fa27ac 100644 --- a/crates/goose-cli/src/prompt/rustyline.rs +++ b/crates/goose-cli/src/prompt/rustyline.rs @@ -8,7 +8,7 @@ use super::{ use anyhow::Result; use cliclack::spinner; -use goose::models::message::Message; +use goose::models::message::{ApprovalRequest, Message}; const PROMPT: &str = "\x1b[1m\x1b[38;5;30m( O)> \x1b[0m"; @@ -128,4 +128,34 @@ impl Prompt for RustylinePrompt { fn as_any(&self) -> &dyn std::any::Any { panic!("Not implemented"); } + + // TODO: Return a richer result type + fn handle_approval_request(&mut self, approval_request: ApprovalRequest) -> Result { + let message = format!( + "Approve tool request {} ? (y/n) [Default: y] ", + approval_request.id + ); + loop { + let mut editor = rustyline::DefaultEditor::new()?; + let response = editor.readline(&message); + let response = match response { + Ok(text) => text, + Err(e) => { + match e { + rustyline::error::ReadlineError::Interrupted => (), + _ => eprintln!("Input error: {}", e), + } + return Err(anyhow::anyhow!("Approval request interrupted")); + } + }; + + if response.eq_ignore_ascii_case("y") || response.is_empty() { + return Ok(true); + } else if response.eq_ignore_ascii_case("n") { + return Ok(false); + } else { + println!("Invalid response. Please enter 'y' or 'n'."); + } + } + } } diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index df8d4468f..4d773bec8 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -1,6 +1,7 @@ use anyhow::Result; use core::panic; use futures::StreamExt; +use goose::agent::ApprovalMonitor; use serde_json; use std::fs::{self, File}; use std::io::{self, BufRead, Write}; @@ -9,7 +10,7 @@ use std::path::PathBuf; use crate::agents::agent::Agent; use crate::prompt::{InputType, Prompt}; use goose::developer::DeveloperSystem; -use goose::models::message::{Message, MessageContent}; +use goose::models::message::{ApprovalRequest, Message, MessageContent}; use goose::models::role::Role; use goose::systems::goose_hints::GooseHintsSystem; @@ -140,7 +141,13 @@ impl<'a> Session<'a> { } async fn agent_process_messages(&mut self) { - let mut stream = match self.agent.reply(&self.messages).await { + let (approval_tx, approval_rx) = tokio::sync::mpsc::channel(1); + let (rejection_tx, rejection_rx) = tokio::sync::mpsc::channel(1); + let approval_monitor = ApprovalMonitor { + approval_rx: approval_rx, + rejection_rx: rejection_rx, + }; + let mut stream = match self.agent.reply(&self.messages, approval_monitor).await { Ok(stream) => stream, Err(e) => { eprintln!("Error starting reply stream: {}", e); @@ -156,6 +163,39 @@ impl<'a> Session<'a> { persist_messages(&self.session_file, &self.messages).unwrap_or_else(|e| eprintln!("Failed to persist messages: {}", e)); self.prompt.hide_busy(); self.prompt.render(Box::new(message.clone())); + + let mut abandon_tools = false; + // Handle any tool requests that require approval + for approval in pending_approvals(message).iter() { + if abandon_tools { + // If we've already abandoned tools, reject any new requests + if let Err(e) = rejection_tx.send(approval.id.clone()).await { + eprintln!("Failed to handle approval: {}", e); + } + continue; + } + // TODO: Auto-approve some tool requests based on external configuration. + match self.prompt.handle_approval_request(approval.clone()) { + Ok(approval_result) => { + if approval_result { + if let Err(e) = approval_tx.send(approval.id.clone()).await { + eprintln!("Failed to send approval: {}", e); + } + } else { + if let Err(e) = rejection_tx.send(approval.id.clone()).await { + eprintln!("Failed to send rejection: {}", e); + } + } + }, + Err(e) => { + abandon_tools = true; + eprintln!("Failed to handle approval: {}", e); + if let Err(e) = rejection_tx.send(approval.id.clone()).await { + eprintln!("Failed to send rejection: {}", e); + } + } + } + } self.prompt.show_busy(); } Some(Err(e)) => { @@ -297,6 +337,20 @@ We've removed the conversation up to the most recent user message } } +fn pending_approvals(message: Message) -> Vec { + message + .content + .iter() + .filter_map(|content| { + if let MessageContent::ToolRequest(req) = content { + req.approval_request.clone() + } else { + None + } + }) + .collect() +} + fn raw_message(content: &str) -> Box { Box::new(Message::assistant().with_text(content)) } diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 4de752454..4327ef0c2 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -8,10 +8,13 @@ use axum::{ }; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use goose::models::{ - content::Content, - message::{Message, MessageContent}, - role::Role, +use goose::{ + agent::ApprovalMonitor, + models::{ + content::Content, + message::{Message, MessageContent}, + role::Role, + }, }; use serde::Deserialize; use serde_json::{json, Value}; @@ -276,10 +279,18 @@ async fn handler( // Get a lock on the shared agent let agent = state.agent.clone(); + // Create a monitor for approvals. TODO: Implement this for gui + let (approval_tx, approval_rx) = mpsc::channel(1); + let (reject_tx, rejection_rx) = mpsc::channel(1); + let approval_monitor = ApprovalMonitor { + approval_rx, + rejection_rx, + }; + // Spawn task to handle streaming tokio::spawn(async move { let agent = agent.lock().await; - let mut stream = match agent.reply(&messages).await { + let mut stream = match agent.reply(&messages, approval_monitor).await { Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {}", e); @@ -347,9 +358,16 @@ async fn ask_handler( // Create a single message for the prompt let messages = vec![Message::user().with_text(request.prompt)]; + let (approval_tx, approval_rx) = mpsc::channel(1); + let (reject_tx, rejection_rx) = mpsc::channel(1); + let approval_monitor = ApprovalMonitor { + approval_rx, + rejection_rx, + }; + // Get response from agent let mut response_text = String::new(); - let mut stream = match agent.reply(&messages).await { + let mut stream = match agent.reply(&messages, approval_monitor).await { Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {}", e); diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs index ee5bbcb46..ee972d820 100644 --- a/crates/goose/src/agent.rs +++ b/crates/goose/src/agent.rs @@ -3,10 +3,11 @@ use async_stream; use futures::stream::BoxStream; use serde_json::json; use std::collections::HashMap; +use tokio::sync::mpsc::Receiver; use crate::errors::{AgentError, AgentResult}; use crate::models::content::Content; -use crate::models::message::{Message, ToolRequest}; +use crate::models::message::{ApprovalRequest, Message, MessageContent, ToolRequest}; use crate::models::tool::{Tool, ToolCall}; use crate::prompt_template::load_prompt_file; use crate::providers::base::Provider; @@ -51,6 +52,11 @@ pub struct Agent { provider: Box, } +pub struct ApprovalMonitor { + pub approval_rx: Receiver, + pub rejection_rx: Receiver, +} + impl Agent { /// Create a new Agent with the specified provider pub fn new(provider: Box) -> Self { @@ -97,6 +103,7 @@ impl Agent { async fn dispatch_tool_call( &self, tool_call: AgentResult, + permission_denied: bool, ) -> AgentResult> { let call = tool_call?; let system = self @@ -110,7 +117,12 @@ impl Agent { .ok_or_else(|| AgentError::InvalidToolName(call.name.clone()))?; let system_tool_call = ToolCall::new(tool_name, call.arguments); - system.call(system_tool_call).await + if permission_denied { + // This should really end up with the audience of assistant as users don't need to see this. + permission_denied_content().await + } else { + system.call(system_tool_call).await + } } fn get_system_prompt(&self) -> AgentResult { @@ -202,7 +214,11 @@ impl Agent { /// Create a stream that yields each message as it's generated by the agent. /// This includes both the assistant's responses and any tool responses. - pub async fn reply(&self, messages: &[Message]) -> Result>> { + pub async fn reply( + &self, + messages: &[Message], + mut approval_monitor: ApprovalMonitor, + ) -> Result>> { let mut messages = messages.to_vec(); let tools = self.get_prefixed_tools(); let system_prompt = self.get_system_prompt()?; @@ -215,12 +231,27 @@ impl Agent { Ok(Box::pin(async_stream::try_stream! { loop { // Get completion from provider - let (response, _) = self.provider.complete( + let (mut response, _) = self.provider.complete( &system_prompt, &messages, &tools, ).await?; + // Add approval requests to any tool requests + let mut approval_map: HashMap = HashMap::new(); + let mut new_content: Vec = vec![]; + response.content.iter().for_each(|content| { + let c = if let MessageContent::ToolRequest(tool_request) = content { + let approval_request = ApprovalRequest::internal(); + approval_map.insert(approval_request.id.clone(), tool_request.id.clone()); + MessageContent::tool_request_with_approval(tool_request.id.clone(), tool_request.tool_call.clone(), Some(approval_request)) + } else { + content.clone() + }; + new_content.push(c); + }); + response.content = new_content; + // The assistant's response is added in rewrite_messages_on_tool_response // Yield the assistant's response yield response.clone(); @@ -240,10 +271,39 @@ impl Agent { break; } + // Wait for all approvals to be handled by the user + let mut rejected_tool_request_ids: Vec = Vec::new(); + while !approval_map.is_empty() { + tokio::select! { + approval = approval_monitor.approval_rx.recv() => { + if let Some(approval) = approval { + approval_map.remove(&approval); + } + } + rejection = approval_monitor.rejection_rx.recv() => { + if let Some(rejection) = rejection { + if let Some(tool_request_id) = approval_map.remove(&rejection) { + rejected_tool_request_ids.push(tool_request_id.clone()); + } + } + } + } + } + + // If we want user role on the requests, have to deal with the duplication of the tool requests from the initial provider.complete yielding. + // + // let message_tool_response = Message::user(); + // yield tool_requests.iter().fold(message_tool_response, |msg, request| { + // // TOOD: request.id is irrelevant here, only the approval id within the tool_call matters. + // msg.with_tool_request_and_approval(request.id.clone(), request.tool_call.clone(), Some(ApprovalRequest::internal())) + // }); + // Then dispatch each in parallel let futures: Vec<_> = tool_requests .iter() - .map(|request| self.dispatch_tool_call(request.tool_call.clone())) + .map(|request| { + self.dispatch_tool_call(request.tool_call.clone(), rejected_tool_request_ids.contains(&request.id)) + }) .collect(); // Process all the futures in parallel but wait until all are finished @@ -273,6 +333,12 @@ impl Agent { } } +async fn permission_denied_content() -> AgentResult> { + Ok(vec![Content::text( + "Permission denied by the user, try an alternative approach.", + )]) +} + #[cfg(test)] mod tests { use super::*; @@ -343,7 +409,9 @@ mod tests { let initial_message = Message::user().with_text("Hi"); let initial_messages = vec![initial_message]; - let mut stream = agent.reply(&initial_messages).await?; + let mut stream = agent + .reply(&initial_messages, mock_approval_monitor()) + .await?; let mut messages = Vec::new(); while let Some(msg) = stream.try_next().await? { messages.push(msg); @@ -369,7 +437,9 @@ mod tests { let initial_message = Message::user().with_text("Echo test"); let initial_messages = vec![initial_message]; - let mut stream = agent.reply(&initial_messages).await?; + let mut stream = agent + .reply(&initial_messages, mock_approval_monitor()) + .await?; let mut messages = Vec::new(); while let Some(msg) = stream.try_next().await? { messages.push(msg); @@ -398,7 +468,9 @@ mod tests { let initial_message = Message::user().with_text("Invalid tool"); let initial_messages = vec![initial_message]; - let mut stream = agent.reply(&initial_messages).await?; + let mut stream = agent + .reply(&initial_messages, mock_approval_monitor()) + .await?; let mut messages = Vec::new(); while let Some(msg) = stream.try_next().await? { messages.push(msg); @@ -437,7 +509,9 @@ mod tests { let initial_message = Message::user().with_text("Multiple calls"); let initial_messages = vec![initial_message]; - let mut stream = agent.reply(&initial_messages).await?; + let mut stream = agent + .reply(&initial_messages, mock_approval_monitor()) + .await?; let mut messages = Vec::new(); while let Some(msg) = stream.try_next().await? { messages.push(msg); @@ -452,4 +526,13 @@ mod tests { assert_eq!(messages[2].content[0], MessageContent::text("All done!")); Ok(()) } + + fn mock_approval_monitor() -> ApprovalMonitor { + let (approval_tx, approval_rx) = tokio::sync::mpsc::channel(1); + let (rejection_tx, rejection_rx) = tokio::sync::mpsc::channel(1); + ApprovalMonitor { + approval_rx, + rejection_rx, + } + } } diff --git a/crates/goose/src/models/message.rs b/crates/goose/src/models/message.rs index aef46d692..6ee4bad1c 100644 --- a/crates/goose/src/models/message.rs +++ b/crates/goose/src/models/message.rs @@ -8,6 +8,26 @@ use chrono::Utc; pub struct ToolRequest { pub id: String, pub tool_call: AgentResult, + pub approval_request: Option, +} + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct ApprovalRequest { + pub id: String, + pub approval_type: ApprovalType, +} +impl ApprovalRequest { + pub fn internal() -> Self { + ApprovalRequest { + id: uuid::Uuid::new_v4().to_string(), + approval_type: ApprovalType::Internal, + } + } +} + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum ApprovalType { + Internal, // Send an approval response to goose with the approval request id. } #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -47,6 +67,19 @@ impl MessageContent { MessageContent::ToolRequest(ToolRequest { id: id.into(), tool_call, + approval_request: None, + }) + } + + pub fn tool_request_with_approval>( + id: S, + tool_call: AgentResult, + approval_request: Option, + ) -> Self { + MessageContent::ToolRequest(ToolRequest { + id: id.into(), + tool_call, + approval_request: approval_request, }) } @@ -135,6 +168,19 @@ impl Message { self.with_content(MessageContent::tool_request(id, tool_call)) } + pub fn with_tool_request_and_approval>( + self, + id: S, + tool_call: AgentResult, + approval_request: Option, + ) -> Self { + self.with_content(MessageContent::tool_request_with_approval( + id, + tool_call, + approval_request, + )) + } + /// Add a tool response to the message pub fn with_tool_response>( self,