Skip to content

Commit

Permalink
feat: Simplify interface by removing subcommands
Browse files Browse the repository at this point in the history
  • Loading branch information
leoshimo committed Nov 30, 2023
1 parent b505c59 commit 229b820
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 121 deletions.
110 changes: 26 additions & 84 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,8 @@ use clap::{
use derive_builder::Builder;

/// CLI invocations that can be launched
#[derive(Debug)]
pub enum Invocation {
/// Invoke chat completion,
ChatCompletion(ChatCompletionArgs),
}

/// Arguments parsed for ChatCompletion
#[derive(Debug, Default, Builder)]
pub struct ChatCompletionArgs {
pub struct Invocation {
pub api_key: Option<String>,
pub messages: Vec<Message>,
pub model: String,
Expand All @@ -36,7 +29,7 @@ pub enum OutputFormat {
JSONPretty,
}

/// Parse commandline arguments into `ChatCompletionArgs`. May exit with help or error message
/// Parse commandline arguments into `Invocation`. May exit with help or error message
#[must_use]
pub fn parse() -> Invocation {
cli().get_matches().into()
Expand All @@ -45,14 +38,6 @@ pub fn parse() -> Invocation {
/// Top-level command
fn cli() -> Command {
command!()
.subcommand(chat_completion_cmd())
.subcommand_required(true)
}

/// Subcommand for chat completion interface
fn chat_completion_cmd() -> Command {
Command::new("chat")
.about("Chat Completion")
.arg(arg!(model: -m --model <MODEL> "Sets model").default_value("gpt-3.5-turbo"))
.arg(
arg!(temperature: -t --temperature <TEMP> "Sets temperature")
Expand Down Expand Up @@ -92,24 +77,9 @@ fn chat_completion_cmd() -> Command {
}

impl From<ArgMatches> for Invocation {
fn from(matches: ArgMatches) -> Self {
use Invocation::*;

let (name, submatch) = matches.subcommand().expect("Subcommands are required");

match name {
"chat" => ChatCompletion(ChatCompletionArgs::from(submatch.to_owned())),
_ => {
panic!("Unrecognized subcommand");
}
}
}
}

impl From<ArgMatches> for ChatCompletionArgs {
fn from(matches: ArgMatches) -> Self {
let api_key = matches.get_one::<String>("api_key").cloned();
let messages = ChatCompletionArgs::messages_from_matches(&matches);
let messages = Invocation::messages_from_matches(&matches);
let model = matches
.get_one::<String>("model")
.expect("Models is required")
Expand Down Expand Up @@ -145,10 +115,10 @@ impl From<ArgMatches> for ChatCompletionArgs {
}
}

impl ChatCompletionArgs {
impl Invocation {
/// Builder
pub fn builder() -> ChatCompletionArgsBuilder {
ChatCompletionArgsBuilder::default()
pub fn builder() -> InvocationBuilder {
InvocationBuilder::default()
}

/// Given `clap::ArgMatches`, creates a vector of `Message` with assigned roles and ordering
Expand Down Expand Up @@ -197,38 +167,25 @@ impl ValueEnum for OutputFormat {

#[cfg(test)]
mod test {
use super::Invocation::*;
use super::*;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

#[test]
fn chat_no_args_is_err() {
let res = cli()
.try_get_matches_from(vec!["cogni"])
.map(Invocation::from);
assert!(res.is_err());
}

#[test]
fn chat_one_msgs() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat", "-u", "USER"])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "USER"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.messages, vec![Message::user("USER")]);
Ok(())
}

#[test]
fn chat_many_msgs() -> Result<()> {
let res = cli()
.try_get_matches_from(vec![
"cogni", "chat", "-u", "USER1", "-a", "ROBOT", "-u", "USER2",
])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "USER1", "-a", "ROBOT", "-u", "USER2"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(
args.messages,
Expand All @@ -244,12 +201,11 @@ mod test {

#[test]
fn chat_many_msgs_with_system_prompt() -> Result<()> {
let res = cli()
let args = cli()
.try_get_matches_from(vec![
"cogni", "chat", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2",
"cogni", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2",
])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(
args.messages,
Expand All @@ -266,12 +222,11 @@ mod test {

#[test]
fn chat_many_msgs_with_system_prompt_last() -> Result<()> {
let res = cli()
let args = cli()
.try_get_matches_from(vec![
"cogni", "chat", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2",
"cogni", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2",
])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(
args.messages,
Expand All @@ -288,10 +243,9 @@ mod test {

#[test]
fn chat_output_format_default() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat", "-u", "ABC"])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "ABC"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(
args.output_format,
Expand All @@ -303,61 +257,49 @@ mod test {

#[test]
fn chat_output_format_explicit_json() -> Result<()> {
let res = cli()
.try_get_matches_from(vec![
"cogni",
"chat",
"-u",
"ABC",
"--output_format",
"json",
])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "ABC", "--output_format", "json"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.output_format, OutputFormat::JSON);
Ok(())
}

#[test]
fn chat_output_format_shorthand_json() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat", "-u", "ABC", "--json"])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "ABC", "--json"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.output_format, OutputFormat::JSON);
Ok(())
}

#[test]
fn chat_output_format_shorthand_jsonp() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat", "-u", "ABC", "--jsonp"])
let args = cli()
.try_get_matches_from(vec!["cogni", "-u", "ABC", "--jsonp"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.output_format, OutputFormat::JSONPretty);
Ok(())
}

#[test]
fn chat_file_default() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat"])
let args = cli()
.try_get_matches_from(vec!["cogni"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.file, "-");
Ok(())
}

#[test]
fn chat_file_positional() -> Result<()> {
let res = cli()
.try_get_matches_from(vec!["cogni", "chat", "dialog_log"])
let args = cli()
.try_get_matches_from(vec!["cogni", "dialog_log"])
.map(Invocation::from)?;
let ChatCompletion(args) = res;

assert_eq!(args.file, "dialog_log");
Ok(())
Expand Down
27 changes: 11 additions & 16 deletions src/exec/chat.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
//! Implements chat subcommand

use crate::cli::{ChatCompletionArgs, OutputFormat};
use crate::cli::{Invocation, OutputFormat};
use crate::openai::{self, ChatCompletion, FinishReason, Message};
use crate::parse;
use crate::Error;

use std::io::{self, Read, Write, BufWriter, IsTerminal};
use anyhow::{Context, Result};
use std::fs::File;
use anyhow::{Result, Context};
use std::io::{self, BufWriter, IsTerminal, Read, Write};

/// Executes `Invocation::ChatCompletion` via given args
pub async fn exec(args: ChatCompletionArgs) -> Result<()> {
let base_url = std::env::var("OPENAI_API_ENDPOINT")
.unwrap_or("https://api.openai.com".to_string());
/// Executes `Invocation` via given args
pub async fn exec(args: Invocation) -> Result<()> {
let base_url =
std::env::var("OPENAI_API_ENDPOINT").unwrap_or("https://api.openai.com".to_string());

let client = openai::Client::new(args.api_key.clone(), base_url)
.with_context(|| "failed to create http client")?;
Expand Down Expand Up @@ -44,7 +44,6 @@ pub async fn exec(args: ChatCompletionArgs) -> Result<()> {
Ok(())
}


/// Read messages from non-tty stdin or file specified by `args.file`
fn read_messages_from_file(file: &str) -> Result<Vec<Message>> {
let reader: Option<Box<dyn Read>> = match file {
Expand All @@ -66,11 +65,7 @@ fn read_messages_from_file(file: &str) -> Result<Vec<Message>> {
}

/// Show formatted output for `ChatCompletionRequest`
fn show_response(
dest: impl Write,
args: &ChatCompletionArgs,
resp: &ChatCompletion,
) -> Result<(), Error> {
fn show_response(dest: impl Write, args: &Invocation, resp: &ChatCompletion) -> Result<(), Error> {
let mut writer = BufWriter::new(dest);
let choice = match resp.choices.len() {
1 => &resp.choices[0],
Expand Down Expand Up @@ -113,7 +108,7 @@ mod test {
use predicates::prelude::*;

use crate::{
cli::{ChatCompletionArgs, ChatCompletionArgsBuilder, OutputFormat},
cli::{Invocation, InvocationBuilder, OutputFormat},
openai::{ChatCompletion, ChatCompletionBuilder, Choice, FinishReason, Message, Usage},
};

Expand Down Expand Up @@ -188,8 +183,8 @@ mod test {
Ok(())
}

fn default_args() -> ChatCompletionArgsBuilder {
ChatCompletionArgs::builder()
fn default_args() -> InvocationBuilder {
Invocation::builder()
.api_key(Some(String::default()))
.messages(vec![])
.model(String::default())
Expand Down
6 changes: 1 addition & 5 deletions src/exec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,5 @@ use anyhow::Result;

/// Execute the invocation
pub async fn exec(inv: Invocation) -> Result<()> {
use Invocation::*;
match inv {
ChatCompletion(args) => chat::exec(args).await,
}
chat::exec(inv).await
}

21 changes: 5 additions & 16 deletions tests/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,10 @@ use assert_fs::prelude::*;
use predicates::prelude::*;
use serde_json::json;

#[test]
fn no_args() {
Command::cargo_bin("cogni")
.unwrap()
.assert()
.failure()
.stderr(predicate::str::contains("Usage"));
}

#[test]
fn chat_no_message() {
Command::cargo_bin("cogni")
.unwrap()
.args(["chat"])
.assert()
.failure()
.stderr(predicate::str::contains("no messages provided"));
Expand All @@ -28,7 +18,7 @@ fn chat_no_message() {
fn chat_no_file() {
Command::cargo_bin("cogni")
.unwrap()
.args(["chat", "file_does_not_exist"])
.args(["file_does_not_exist"])
.assert()
.failure()
.stderr(predicate::str::contains(
Expand Down Expand Up @@ -74,7 +64,7 @@ fn chat_user_message_from_flag() {

let cmd = Command::cargo_bin("cogni")
.unwrap()
.args(["chat", "-u", "Hello"])
.args(["-u", "Hello"])
.env("OPENAI_API_ENDPOINT", server.url())
.env("OPENAI_API_KEY", "ABCDE")
.assert();
Expand Down Expand Up @@ -123,7 +113,6 @@ fn chat_user_message_from_stdin() {

let cmd = Command::cargo_bin("cogni")
.unwrap()
.args(["chat"])
.write_stdin("Hello")
.env("OPENAI_API_ENDPOINT", server.url())
.env("OPENAI_API_KEY", "ABCDE")
Expand Down Expand Up @@ -194,7 +183,7 @@ fn chat_multiple_messages() {
let cmd = Command::cargo_bin("cogni")
.unwrap()
.args([
"chat", "-s", "SYSTEM", "-u", "USER_1", "-a", "ASSI_1", "-u", "USER_2", "-a", "ASSI_2",
"-s", "SYSTEM", "-u", "USER_1", "-a", "ASSI_1", "-u", "USER_2", "-a", "ASSI_2",
])
.write_stdin("USER_STDIN")
.env("OPENAI_API_ENDPOINT", server.url())
Expand Down Expand Up @@ -238,7 +227,7 @@ fn chat_api_error() {

let cmd = Command::cargo_bin("cogni")
.unwrap()
.args(["chat", "-u", "USER", "-t", "1000"])
.args(["-u", "USER", "-t", "1000"])
.write_stdin("USER_STDIN")
.env("OPENAI_API_ENDPOINT", server.url())
.env("OPENAI_API_KEY", "ABCDE")
Expand Down Expand Up @@ -293,7 +282,7 @@ fn chat_user_message_from_file() {

let cmd = Command::cargo_bin("cogni")
.unwrap()
.args(["chat", infile.path().to_str().unwrap()])
.args([infile.path().to_str().unwrap()])
.env("OPENAI_API_ENDPOINT", server.url())
.env("OPENAI_API_KEY", "ABCDE")
.assert();
Expand Down

0 comments on commit 229b820

Please sign in to comment.