From 229b8208a483b540a8740fa561f587d7d8734f4c Mon Sep 17 00:00:00 2001 From: leoshimo <56844000+leoshimo@users.noreply.github.com> Date: Thu, 30 Nov 2023 13:15:24 -0800 Subject: [PATCH] feat: Simplify interface by removing subcommands --- src/cli.rs | 110 +++++++++++------------------------------------ src/exec/chat.rs | 27 +++++------- src/exec/mod.rs | 6 +-- tests/chat.rs | 21 +++------ 4 files changed, 43 insertions(+), 121 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index cc070aa..5de9356 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -9,15 +9,8 @@ use clap::{ use derive_builder::Builder; /// CLI invocations that can be launched -#[derive(Debug)] -pub enum Invocation { - /// Invoke chat completion, - ChatCompletion(ChatCompletionArgs), -} - -/// Arguments parsed for ChatCompletion #[derive(Debug, Default, Builder)] -pub struct ChatCompletionArgs { +pub struct Invocation { pub api_key: Option, pub messages: Vec, pub model: String, @@ -36,7 +29,7 @@ pub enum OutputFormat { JSONPretty, } -/// Parse commandline arguments into `ChatCompletionArgs`. May exit with help or error message +/// Parse commandline arguments into `Invocation`. May exit with help or error message #[must_use] pub fn parse() -> Invocation { cli().get_matches().into() @@ -45,14 +38,6 @@ pub fn parse() -> Invocation { /// Top-level command fn cli() -> Command { command!() - .subcommand(chat_completion_cmd()) - .subcommand_required(true) -} - -/// Subcommand for chat completion interface -fn chat_completion_cmd() -> Command { - Command::new("chat") - .about("Chat Completion") .arg(arg!(model: -m --model "Sets model").default_value("gpt-3.5-turbo")) .arg( arg!(temperature: -t --temperature "Sets temperature") @@ -92,24 +77,9 @@ fn chat_completion_cmd() -> Command { } impl From for Invocation { - fn from(matches: ArgMatches) -> Self { - use Invocation::*; - - let (name, submatch) = matches.subcommand().expect("Subcommands are required"); - - match name { - "chat" => ChatCompletion(ChatCompletionArgs::from(submatch.to_owned())), - _ => { - panic!("Unrecognized subcommand"); - } - } - } -} - -impl From for ChatCompletionArgs { fn from(matches: ArgMatches) -> Self { let api_key = matches.get_one::("api_key").cloned(); - let messages = ChatCompletionArgs::messages_from_matches(&matches); + let messages = Invocation::messages_from_matches(&matches); let model = matches .get_one::("model") .expect("Models is required") @@ -145,10 +115,10 @@ impl From for ChatCompletionArgs { } } -impl ChatCompletionArgs { +impl Invocation { /// Builder - pub fn builder() -> ChatCompletionArgsBuilder { - ChatCompletionArgsBuilder::default() + pub fn builder() -> InvocationBuilder { + InvocationBuilder::default() } /// Given `clap::ArgMatches`, creates a vector of `Message` with assigned roles and ordering @@ -197,25 +167,15 @@ impl ValueEnum for OutputFormat { #[cfg(test)] mod test { - use super::Invocation::*; use super::*; type Result = std::result::Result>; - #[test] - fn chat_no_args_is_err() { - let res = cli() - .try_get_matches_from(vec!["cogni"]) - .map(Invocation::from); - assert!(res.is_err()); - } - #[test] fn chat_one_msgs() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat", "-u", "USER"]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "USER"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.messages, vec![Message::user("USER")]); Ok(()) @@ -223,12 +183,9 @@ mod test { #[test] fn chat_many_msgs() -> Result<()> { - let res = cli() - .try_get_matches_from(vec![ - "cogni", "chat", "-u", "USER1", "-a", "ROBOT", "-u", "USER2", - ]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "USER1", "-a", "ROBOT", "-u", "USER2"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!( args.messages, @@ -244,12 +201,11 @@ mod test { #[test] fn chat_many_msgs_with_system_prompt() -> Result<()> { - let res = cli() + let args = cli() .try_get_matches_from(vec![ - "cogni", "chat", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2", + "cogni", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2", ]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!( args.messages, @@ -266,12 +222,11 @@ mod test { #[test] fn chat_many_msgs_with_system_prompt_last() -> Result<()> { - let res = cli() + let args = cli() .try_get_matches_from(vec![ - "cogni", "chat", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2", + "cogni", "-s", "SYSTEM", "-u", "USER1", "-a", "ROBOT", "-u", "USER2", ]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!( args.messages, @@ -288,10 +243,9 @@ mod test { #[test] fn chat_output_format_default() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat", "-u", "ABC"]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "ABC"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!( args.output_format, @@ -303,17 +257,9 @@ mod test { #[test] fn chat_output_format_explicit_json() -> Result<()> { - let res = cli() - .try_get_matches_from(vec![ - "cogni", - "chat", - "-u", - "ABC", - "--output_format", - "json", - ]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "ABC", "--output_format", "json"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.output_format, OutputFormat::JSON); Ok(()) @@ -321,10 +267,9 @@ mod test { #[test] fn chat_output_format_shorthand_json() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat", "-u", "ABC", "--json"]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "ABC", "--json"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.output_format, OutputFormat::JSON); Ok(()) @@ -332,10 +277,9 @@ mod test { #[test] fn chat_output_format_shorthand_jsonp() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat", "-u", "ABC", "--jsonp"]) + let args = cli() + .try_get_matches_from(vec!["cogni", "-u", "ABC", "--jsonp"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.output_format, OutputFormat::JSONPretty); Ok(()) @@ -343,10 +287,9 @@ mod test { #[test] fn chat_file_default() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat"]) + let args = cli() + .try_get_matches_from(vec!["cogni"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.file, "-"); Ok(()) @@ -354,10 +297,9 @@ mod test { #[test] fn chat_file_positional() -> Result<()> { - let res = cli() - .try_get_matches_from(vec!["cogni", "chat", "dialog_log"]) + let args = cli() + .try_get_matches_from(vec!["cogni", "dialog_log"]) .map(Invocation::from)?; - let ChatCompletion(args) = res; assert_eq!(args.file, "dialog_log"); Ok(()) diff --git a/src/exec/chat.rs b/src/exec/chat.rs index 2342517..6bfac73 100644 --- a/src/exec/chat.rs +++ b/src/exec/chat.rs @@ -1,18 +1,18 @@ //! Implements chat subcommand -use crate::cli::{ChatCompletionArgs, OutputFormat}; +use crate::cli::{Invocation, OutputFormat}; use crate::openai::{self, ChatCompletion, FinishReason, Message}; use crate::parse; use crate::Error; -use std::io::{self, Read, Write, BufWriter, IsTerminal}; +use anyhow::{Context, Result}; use std::fs::File; -use anyhow::{Result, Context}; +use std::io::{self, BufWriter, IsTerminal, Read, Write}; -/// Executes `Invocation::ChatCompletion` via given args -pub async fn exec(args: ChatCompletionArgs) -> Result<()> { - let base_url = std::env::var("OPENAI_API_ENDPOINT") - .unwrap_or("https://api.openai.com".to_string()); +/// Executes `Invocation` via given args +pub async fn exec(args: Invocation) -> Result<()> { + let base_url = + std::env::var("OPENAI_API_ENDPOINT").unwrap_or("https://api.openai.com".to_string()); let client = openai::Client::new(args.api_key.clone(), base_url) .with_context(|| "failed to create http client")?; @@ -44,7 +44,6 @@ pub async fn exec(args: ChatCompletionArgs) -> Result<()> { Ok(()) } - /// Read messages from non-tty stdin or file specified by `args.file` fn read_messages_from_file(file: &str) -> Result> { let reader: Option> = match file { @@ -66,11 +65,7 @@ fn read_messages_from_file(file: &str) -> Result> { } /// Show formatted output for `ChatCompletionRequest` -fn show_response( - dest: impl Write, - args: &ChatCompletionArgs, - resp: &ChatCompletion, -) -> Result<(), Error> { +fn show_response(dest: impl Write, args: &Invocation, resp: &ChatCompletion) -> Result<(), Error> { let mut writer = BufWriter::new(dest); let choice = match resp.choices.len() { 1 => &resp.choices[0], @@ -113,7 +108,7 @@ mod test { use predicates::prelude::*; use crate::{ - cli::{ChatCompletionArgs, ChatCompletionArgsBuilder, OutputFormat}, + cli::{Invocation, InvocationBuilder, OutputFormat}, openai::{ChatCompletion, ChatCompletionBuilder, Choice, FinishReason, Message, Usage}, }; @@ -188,8 +183,8 @@ mod test { Ok(()) } - fn default_args() -> ChatCompletionArgsBuilder { - ChatCompletionArgs::builder() + fn default_args() -> InvocationBuilder { + Invocation::builder() .api_key(Some(String::default())) .messages(vec![]) .model(String::default()) diff --git a/src/exec/mod.rs b/src/exec/mod.rs index f4d60ba..af78f4c 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -6,9 +6,5 @@ use anyhow::Result; /// Execute the invocation pub async fn exec(inv: Invocation) -> Result<()> { - use Invocation::*; - match inv { - ChatCompletion(args) => chat::exec(args).await, - } + chat::exec(inv).await } - diff --git a/tests/chat.rs b/tests/chat.rs index 3a9cc60..34fc261 100644 --- a/tests/chat.rs +++ b/tests/chat.rs @@ -5,20 +5,10 @@ use assert_fs::prelude::*; use predicates::prelude::*; use serde_json::json; -#[test] -fn no_args() { - Command::cargo_bin("cogni") - .unwrap() - .assert() - .failure() - .stderr(predicate::str::contains("Usage")); -} - #[test] fn chat_no_message() { Command::cargo_bin("cogni") .unwrap() - .args(["chat"]) .assert() .failure() .stderr(predicate::str::contains("no messages provided")); @@ -28,7 +18,7 @@ fn chat_no_message() { fn chat_no_file() { Command::cargo_bin("cogni") .unwrap() - .args(["chat", "file_does_not_exist"]) + .args(["file_does_not_exist"]) .assert() .failure() .stderr(predicate::str::contains( @@ -74,7 +64,7 @@ fn chat_user_message_from_flag() { let cmd = Command::cargo_bin("cogni") .unwrap() - .args(["chat", "-u", "Hello"]) + .args(["-u", "Hello"]) .env("OPENAI_API_ENDPOINT", server.url()) .env("OPENAI_API_KEY", "ABCDE") .assert(); @@ -123,7 +113,6 @@ fn chat_user_message_from_stdin() { let cmd = Command::cargo_bin("cogni") .unwrap() - .args(["chat"]) .write_stdin("Hello") .env("OPENAI_API_ENDPOINT", server.url()) .env("OPENAI_API_KEY", "ABCDE") @@ -194,7 +183,7 @@ fn chat_multiple_messages() { let cmd = Command::cargo_bin("cogni") .unwrap() .args([ - "chat", "-s", "SYSTEM", "-u", "USER_1", "-a", "ASSI_1", "-u", "USER_2", "-a", "ASSI_2", + "-s", "SYSTEM", "-u", "USER_1", "-a", "ASSI_1", "-u", "USER_2", "-a", "ASSI_2", ]) .write_stdin("USER_STDIN") .env("OPENAI_API_ENDPOINT", server.url()) @@ -238,7 +227,7 @@ fn chat_api_error() { let cmd = Command::cargo_bin("cogni") .unwrap() - .args(["chat", "-u", "USER", "-t", "1000"]) + .args(["-u", "USER", "-t", "1000"]) .write_stdin("USER_STDIN") .env("OPENAI_API_ENDPOINT", server.url()) .env("OPENAI_API_KEY", "ABCDE") @@ -293,7 +282,7 @@ fn chat_user_message_from_file() { let cmd = Command::cargo_bin("cogni") .unwrap() - .args(["chat", infile.path().to_str().unwrap()]) + .args([infile.path().to_str().unwrap()]) .env("OPENAI_API_ENDPOINT", server.url()) .env("OPENAI_API_KEY", "ABCDE") .assert();