Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require approval in cli #441

Draft
wants to merge 1 commit into
base: v1.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions crates/goose-cli/src/agents/agent.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{agent::Agent as GooseAgent, models::message::Message, systems::System};
use goose::{
agent::{Agent as GooseAgent, ApprovalMonitor},
models::message::Message,
systems::System,
};

#[async_trait]
pub trait Agent {
fn add_system(&mut self, system: Box<dyn System>);
async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>>;
async fn reply(
&self,
messages: &[Message],
approval_monitor: ApprovalMonitor,
) -> Result<BoxStream<'_, Result<Message>>>;
}

#[async_trait]
Expand All @@ -15,7 +23,11 @@ impl Agent for GooseAgent {
self.add_system(system);
}

async fn reply(&self, messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
self.reply(messages).await
async fn reply(
&self,
messages: &[Message],
approval_monitor: ApprovalMonitor,
) -> Result<BoxStream<'_, Result<Message>>> {
self.reply(messages, approval_monitor).await
}
}
8 changes: 6 additions & 2 deletions crates/goose-cli/src/agents/mock_agent.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use goose::{models::message::Message, systems::System};
use goose::{agent::ApprovalMonitor, models::message::Message, systems::System};

use crate::agents::agent::Agent;

Expand All @@ -11,7 +11,11 @@ pub struct MockAgent;
impl Agent for MockAgent {
fn add_system(&mut self, _system: Box<dyn System>) {}

async fn reply(&self, _messages: &[Message]) -> Result<BoxStream<'_, Result<Message>>> {
async fn reply(
&self,
_messages: &[Message],
approval_monitor: ApprovalMonitor,
) -> Result<BoxStream<'_, Result<Message>>> {
Ok(Box::pin(futures::stream::empty()))
}
}
5 changes: 4 additions & 1 deletion crates/goose-cli/src/prompt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::Result;
use goose::models::message::Message;
use goose::models::message::{ApprovalRequest, Message};

pub mod cliclack;
pub mod renderer;
Expand All @@ -9,6 +9,9 @@ pub mod thinking;
pub trait Prompt {
fn render(&mut self, message: Box<Message>);
fn get_input(&mut self) -> Result<Input>;
fn handle_approval_request(&mut self, approval_request: ApprovalRequest) -> Result<bool> {
Err(anyhow::anyhow!("Not implemented"))
}
fn show_busy(&mut self);
fn hide_busy(&self);
fn close(&self);
Expand Down
22 changes: 21 additions & 1 deletion crates/goose-cli/src/prompt/renderer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::{self, Write};

use bat::WrappingMode;
use console::style;
use goose::models::message::{Message, MessageContent, ToolRequest, ToolResponse};
use goose::models::message::{ApprovalRequest, Message, MessageContent, ToolRequest, ToolResponse};
use goose::models::role::Role;
use goose::models::{content::Content, tool::ToolCall};
use serde_json::Value;
Expand Down Expand Up @@ -57,6 +57,7 @@ impl ToolRenderer for DefaultRenderer {

// Format and print the parameters
print_params(&call.arguments, 0);
render_approval_required(tool_request.approval_request.clone());
print_newline();
}
Err(e) => print_markdown(&e.to_string(), theme),
Expand Down Expand Up @@ -87,6 +88,7 @@ impl ToolRenderer for BashDeveloperSystemRenderer {
}
_ => print_params(&call.arguments, 0),
}
render_approval_required(tool_request.approval_request.clone());
print_newline();
}
Err(e) => print_markdown(&e.to_string(), theme),
Expand Down Expand Up @@ -186,6 +188,24 @@ pub fn default_print_request_header(call: &ToolCall) {
println!("{}", tool_header);
}

pub fn render_approval_required(approval_request: Option<ApprovalRequest>) {
if let Some(approval_request) = approval_request {
println!(
"{}",
style("────────────────────────────────────────────────────────").red()
);
println!(
"{} Approval request id: {}",
style("|").red(),
approval_request.id
);
println!(
"{}",
style("────────────────────────────────────────────────────────").red()
);
}
}

pub fn print_markdown(content: &str, theme: &str) {
bat::PrettyPrinter::new()
.input(bat::Input::from_bytes(content.as_bytes()))
Expand Down
32 changes: 31 additions & 1 deletion crates/goose-cli/src/prompt/rustyline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::{

use anyhow::Result;
use cliclack::spinner;
use goose::models::message::Message;
use goose::models::message::{ApprovalRequest, Message};

const PROMPT: &str = "\x1b[1m\x1b[38;5;30m( O)> \x1b[0m";

Expand Down Expand Up @@ -128,4 +128,34 @@ impl Prompt for RustylinePrompt {
fn as_any(&self) -> &dyn std::any::Any {
panic!("Not implemented");
}

// TODO: Return a richer result type
fn handle_approval_request(&mut self, approval_request: ApprovalRequest) -> Result<bool> {
let message = format!(
"Approve tool request {} ? (y/n) [Default: y] ",
approval_request.id
);
loop {
let mut editor = rustyline::DefaultEditor::new()?;
let response = editor.readline(&message);
let response = match response {
Ok(text) => text,
Err(e) => {
match e {
rustyline::error::ReadlineError::Interrupted => (),
_ => eprintln!("Input error: {}", e),
}
return Err(anyhow::anyhow!("Approval request interrupted"));
}
};

if response.eq_ignore_ascii_case("y") || response.is_empty() {
return Ok(true);
} else if response.eq_ignore_ascii_case("n") {
return Ok(false);
} else {
println!("Invalid response. Please enter 'y' or 'n'.");
}
}
}
}
58 changes: 56 additions & 2 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use anyhow::Result;
use core::panic;
use futures::StreamExt;
use goose::agent::ApprovalMonitor;
use serde_json;
use std::fs::{self, File};
use std::io::{self, BufRead, Write};
Expand All @@ -9,7 +10,7 @@ use std::path::PathBuf;
use crate::agents::agent::Agent;
use crate::prompt::{InputType, Prompt};
use goose::developer::DeveloperSystem;
use goose::models::message::{Message, MessageContent};
use goose::models::message::{ApprovalRequest, Message, MessageContent};
use goose::models::role::Role;
use goose::systems::goose_hints::GooseHintsSystem;

Expand Down Expand Up @@ -140,7 +141,13 @@ impl<'a> Session<'a> {
}

async fn agent_process_messages(&mut self) {
let mut stream = match self.agent.reply(&self.messages).await {
let (approval_tx, approval_rx) = tokio::sync::mpsc::channel(1);
let (rejection_tx, rejection_rx) = tokio::sync::mpsc::channel(1);
let approval_monitor = ApprovalMonitor {
approval_rx: approval_rx,
rejection_rx: rejection_rx,
};
let mut stream = match self.agent.reply(&self.messages, approval_monitor).await {
Ok(stream) => stream,
Err(e) => {
eprintln!("Error starting reply stream: {}", e);
Expand All @@ -156,6 +163,39 @@ impl<'a> Session<'a> {
persist_messages(&self.session_file, &self.messages).unwrap_or_else(|e| eprintln!("Failed to persist messages: {}", e));
self.prompt.hide_busy();
self.prompt.render(Box::new(message.clone()));

let mut abandon_tools = false;
// Handle any tool requests that require approval
for approval in pending_approvals(message).iter() {
if abandon_tools {
// If we've already abandoned tools, reject any new requests
if let Err(e) = rejection_tx.send(approval.id.clone()).await {
eprintln!("Failed to handle approval: {}", e);
}
continue;
}
// TODO: Auto-approve some tool requests based on external configuration.
match self.prompt.handle_approval_request(approval.clone()) {
Ok(approval_result) => {
if approval_result {
if let Err(e) = approval_tx.send(approval.id.clone()).await {
eprintln!("Failed to send approval: {}", e);
}
} else {
if let Err(e) = rejection_tx.send(approval.id.clone()).await {
eprintln!("Failed to send rejection: {}", e);
}
}
},
Err(e) => {
abandon_tools = true;
eprintln!("Failed to handle approval: {}", e);
if let Err(e) = rejection_tx.send(approval.id.clone()).await {
eprintln!("Failed to send rejection: {}", e);
}
}
}
}
self.prompt.show_busy();
}
Some(Err(e)) => {
Expand Down Expand Up @@ -297,6 +337,20 @@ We've removed the conversation up to the most recent user message
}
}

fn pending_approvals(message: Message) -> Vec<ApprovalRequest> {
message
.content
.iter()
.filter_map(|content| {
if let MessageContent::ToolRequest(req) = content {
req.approval_request.clone()
} else {
None
}
})
.collect()
}

fn raw_message(content: &str) -> Box<Message> {
Box::new(Message::assistant().with_text(content))
}
Expand Down
30 changes: 24 additions & 6 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ use axum::{
};
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::models::{
content::Content,
message::{Message, MessageContent},
role::Role,
use goose::{
agent::ApprovalMonitor,
models::{
content::Content,
message::{Message, MessageContent},
role::Role,
},
};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -276,10 +279,18 @@ async fn handler(
// Get a lock on the shared agent
let agent = state.agent.clone();

// Create a monitor for approvals. TODO: Implement this for gui
let (approval_tx, approval_rx) = mpsc::channel(1);
let (reject_tx, rejection_rx) = mpsc::channel(1);
let approval_monitor = ApprovalMonitor {
approval_rx,
rejection_rx,
};

// Spawn task to handle streaming
tokio::spawn(async move {
let agent = agent.lock().await;
let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, approval_monitor).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
Expand Down Expand Up @@ -347,9 +358,16 @@ async fn ask_handler(
// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];

let (approval_tx, approval_rx) = mpsc::channel(1);
let (reject_tx, rejection_rx) = mpsc::channel(1);
let approval_monitor = ApprovalMonitor {
approval_rx,
rejection_rx,
};

// Get response from agent
let mut response_text = String::new();
let mut stream = match agent.reply(&messages).await {
let mut stream = match agent.reply(&messages, approval_monitor).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
Expand Down
Loading