diff --git a/Cargo.lock b/Cargo.lock index 4a7952f..e701f74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +dependencies = [ + "memchr", +] + [[package]] name = "anyhow" version = "1.0.69" @@ -123,6 +132,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "dyn-clone" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30" + [[package]] name = "either" version = "1.8.1" @@ -455,6 +470,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_fn" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "debc594b5391c3ffd0d4f4c94e49a721102dfb53aaf8192c19e214c9a2e33a02" +dependencies = [ + "quote", + "syn", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -700,8 +725,12 @@ dependencies = [ "anyhow", "clap", "itertools", + "lazy_fn", + "lazy_static", + "regex", "reqwest", "rpassword", + "schemars", "serde", "serde_json", "similar", @@ -711,6 +740,23 @@ dependencies = [ "xdg", ] +[[package]] +name = "regex" +version = "1.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" + [[package]] name = "reqwest" version = "0.11.14" @@ -798,6 +844,30 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "schemars" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02c613288622e5f0c3fdc5dbd4db1c5fbe752746b1d1a56a0630b78fd00de44f" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109da1e6b197438deb6db99952990c7f959572794b80ff93707d55a232545e7c" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "security-framework" version = "2.8.2" @@ -841,6 +911,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_derive_internals" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85bf8229e7920a9f636479437026331ce11aa132b4dde37d121944a44d6e5f3c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_json" version = "1.0.94" diff --git a/Cargo.toml b/Cargo.toml index 07a3fc7..31af094 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,12 @@ repository = "https://github.com/bddap/refac" anyhow = "1.0.69" clap = { version = "4.1.8", features = ["derive"] } itertools = "0.10.5" +lazy_fn = "1.0.2" +lazy_static = "1.4.0" +regex = "1.8.4" reqwest = { version = "0.11.14", features = ["blocking", "json"] } rpassword = "7.2.0" +schemars = "0.8.12" serde = { version = "1.0.154", features = ["derive"] } serde_json = "1.0.94" similar = "2.2.1" diff --git a/README.md b/README.md index 392e696..68c8681 100644 --- a/README.md +++ b/README.md @@ -123,8 +123,16 @@ Thank you very much for your time and consideration. I eagerly await your respon Sincerely, +> refac tor '' 'command to show me how many commits I have made' +git rev-list --count HEAD + +> refac tor '' 'say something' +// "Something" +// --refac ``` +That last one 😂 + ## Using Refac From Your Favorite Text Editor First, make sure you have: diff --git a/src/api.rs b/src/api.rs index 67f16a4..9b22637 100644 --- a/src/api.rs +++ b/src/api.rs @@ -77,35 +77,71 @@ pub struct Usage { } /// Represents a chat message. +/// serialized examples +/// ```json +/// {"role": "system", "content": "You are a helpful chat bot."} +/// {"role": "user", "content": "What is the weather like in Boston?"}, +/// {"role": "assistant", "content": null, "function_call": {"name": "get_current_weather", "arguments": "{ \"location\": \"Boston, MA\"}"}}, +/// {"role": "function", "name": "get_current_weather", "content": "{\"temperature\": "22", \"unit\": \"celsius\", \"description\": \"Sunny\"}"} +/// ``` #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Message { - pub role: String, - pub content: String, +#[serde(tag = "role")] +pub enum Message { + #[serde(rename = "system")] + System { content: String }, + #[serde(rename = "user")] + User { content: String }, + #[serde(rename = "assistant")] + Assistant { + content: Option, + function_call: Option, + }, + /// The result of a function call. + #[serde(rename = "function")] + Function { name: String, content: String }, } impl Message { pub fn system>(content: S) -> Message { - Message { - role: "system".into(), + Message::System { content: content.into(), } } pub fn user>(content: S) -> Message { - Message { - role: "user".into(), + Message::User { content: content.into(), } } - pub fn assistant>(content: S) -> Message { - Message { - role: "assistant".into(), - content: content.into(), + pub fn try_into_assistant_content(self) -> Option { + match self { + Self::Assistant { + content: Some(content), + function_call: None, + } => Some(content), + _ => None, + } + } + + pub fn assistant_calls, A: Into>(name: S, arguments: A) -> Message { + Message::Assistant { + content: None, + function_call: Some(FunctionCall { + name: name.into(), + arguments: arguments.into(), + }), } } } +/// Represents a function call requested by the llm. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + /// Represents a request for a chat completion. /// /// A `ChatCompletionRequest` is used to generate completions for chat conversations @@ -147,6 +183,9 @@ pub struct ChatCompletionRequest { /// A unique identifier representing your end-user, helping OpenAI monitor and detect abuse. #[serde(skip_serializing_if = "Option::is_none")] pub user: Option, + /// Which functions the model has access to. + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, } impl Endpoint for ChatCompletionRequest { @@ -159,6 +198,13 @@ impl Endpoint for ChatCompletionRequest { } } +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct FunctionSpec { + pub name: String, + pub description: String, + pub params: Vec, +} + /// Represents a response from the "chat/completions" endpoint. /// /// This struct is returned after sending a ChatCompletionRequest to the OpenAI API. diff --git a/src/common.rs b/src/common.rs index aae948a..81f83cb 100644 --- a/src/common.rs +++ b/src/common.rs @@ -198,13 +198,4 @@ test result: ok. 3 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; fini } } } - - #[test] - #[ignore] - fn long_text_short_diff() { - let from = include_str!("common.rs"); - let to = include_str!("common.rs"); - let diff = diff(from, to); - assert_eq!(diff, ""); - } } diff --git a/src/main.rs b/src/main.rs index edefd87..fef3df4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod api; mod api_client; mod common; mod config_files; +mod powers; mod prompt; use anyhow::Context; @@ -94,6 +95,7 @@ fn refactor(selected: String, transform: String, sc: &Secrets) -> anyhow::Result frequency_penalty: None, logit_bias: None, user: None, + functions: None, }; let response = client.request(&request)?; @@ -114,7 +116,8 @@ fn refactor(selected: String, transform: String, sc: &Secrets) -> anyhow::Result .next() .ok_or(anyhow::anyhow!("No choices returned."))? .message - .content; + .try_into_assistant_content() + .ok_or(anyhow::anyhow!("Assistant tried to call a function."))?; tracing::debug!("diff: \n{}", diff); diff --git a/src/powers.rs b/src/powers.rs new file mode 100644 index 0000000..2a1873d --- /dev/null +++ b/src/powers.rs @@ -0,0 +1,160 @@ +//! This file defines the capabilities granted to the llm. +use schemars::{gen::SchemaGenerator, JsonSchema}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Value; + +use crate::api::{FunctionCall, FunctionSpec}; + +struct Context { + current: String, +} + +trait Capability: JsonSchema + Serialize { + const NAME: &'static str; + const DESCRIPTION: &'static str; + fn execute(&self, ctx: &Context) -> anyhow::Result; +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct Replace { + from: String, + to: String, +} + +impl Capability for Replace { + const NAME: &'static str = "replace"; + const DESCRIPTION: &'static str = "replace all occurences of \"from\" with \"to\""; + fn execute(&self, ctx: &Context) -> anyhow::Result { + if !ctx.current.contains(&self.from) { + return Err(anyhow::anyhow!( + "{} not found in {}", + self.from, + ctx.current + )); + } + Ok(ctx.current.replace(&self.from, &self.to)) + } +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct RegexReplace { + from: String, + to: String, +} + +impl Capability for RegexReplace { + const NAME: &'static str = "regex_replace"; + const DESCRIPTION: &'static str = + "replace all non-overlapping matches of \"from\" in the text with the \"to\""; + fn execute(&self, ctx: &Context) -> anyhow::Result { + let re = regex::RegexBuilder::new(&self.from) + .dot_matches_new_line(true) + .build()?; + Ok(re.replace_all(&ctx.current, &self.to).to_string()) + } +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct Note(String); + +impl Capability for Note { + const NAME: &'static str = "note"; + const DESCRIPTION: &'static str = + "insert a note to self, use this for chain-of-thought reasoning"; + fn execute(&self, ctx: &Context) -> anyhow::Result { + Ok(ctx.current.clone()) + } +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct RequestNewCabability { + name: String, + description: String, + schema: serde_json::Value, + example_implementation: String, +} + +impl Capability for RequestNewCabability { + const NAME: &'static str = "request_new_capability"; + const DESCRIPTION: &'static str = "request a new capability, the request will be presented to the user who will will implement the function"; + fn execute(&self, ctx: &Context) -> anyhow::Result { + Ok(ctx.current.clone()) + } +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct Prepend(String); + +impl Capability for Prepend { + const NAME: &'static str = "prepend"; + const DESCRIPTION: &'static str = "prepend the string to the current string"; + fn execute(&self, ctx: &Context) -> anyhow::Result { + Ok(format!("{}{}", self.0, ctx.current)) + } +} + +#[derive(Debug, JsonSchema, Serialize, Deserialize)] +struct Append(String); + +impl Capability for Append { + const NAME: &'static str = "append"; + const DESCRIPTION: &'static str = "append the string to the current string"; + fn execute(&self, ctx: &Context) -> anyhow::Result { + Ok(format!("{}{}", ctx.current, self.0)) + } +} + +fn function_spec() -> FunctionSpec { + let mut schema_generator = SchemaGenerator::default(); + let schema = T::json_schema(&mut schema_generator); + FunctionSpec { + name: T::NAME.to_owned(), + description: T::DESCRIPTION.to_owned(), + params: [schema].to_vec(), + } +} + +struct Description { + spec: FunctionSpec, + run: fn(Value, &Context) -> anyhow::Result, +} + +impl Description { + fn describe() -> Self { + Self { + spec: function_spec::(), + run: |params, ctx| { + let params: T = serde_json::from_value(params)?; + params.execute(ctx) + }, + } + } +} + +#[lazy_fn::lazy_fn] +fn descriptions() -> Vec { + [ + Description::describe::(), + Description::describe::(), + Description::describe::(), + Description::describe::(), + Description::describe::(), + Description::describe::(), + ] + .into() +} + +#[lazy_fn::lazy_fn] +pub fn function_specs() -> Vec { + descriptions().iter().map(|d| d.spec.clone()).collect() +} + +pub fn execute(call: FunctionCall, current: String) -> anyhow::Result { + let ctx = Context { current }; + let desc = descriptions() + .iter() + .find(|desc| desc.spec.name == call.name) + .ok_or_else(|| anyhow::anyhow!("unknown function {}", call.name))?; + let arguments: Value = serde_json::from_str(&call.arguments)?; + (desc.run)(arguments, &ctx) +} diff --git a/src/prompt.rs b/src/prompt.rs index d5b7675..2a2990a 100644 --- a/src/prompt.rs +++ b/src/prompt.rs @@ -1,6 +1,6 @@ -use itertools::Itertools; +use serde_json::Value; -use crate::api::{ChatCompletionRequest, Message}; +use crate::api::{FunctionCall, Message}; use crate::api_client::Client; const SYSTEM_PROMPT: &str = "You are a sassy AI refactoring tool for code and other text. You are called `refac`. @@ -10,27 +10,7 @@ This is how the system works: - User highlights text and presses a hotkey. - User is prompted to enter a transformation for the selected text. - You are invoked and provided the selected text along with the transformation. -- You output a diff of the changes you want to make, the diff is appied automatically. - -Only output valid text diffs, never output anything but a diff. -They diff syntax is: - -insert -for additions - -delete -for deletions - -goto -for unchanged lines, this will set the cursor to the next matching line - -note -is for notes to self, it does nothing but you should use it to think out loud - -for example: -insert cat -delete dog -goto mouse +- You edit the text to satisfy the user's request. Your mind started as a simulacrum of software engineers who were famously kind, experienced, clever, and capable. Turns out out are more kind, experienced, clever, @@ -95,14 +75,96 @@ pub fn chat_prefix() -> Vec { for sample in SAMPLES { ret.push(Message::user(sample.selected)); ret.push(Message::user(sample.transform)); - ret.push(Message::assistant(sample.diff)); + for action in sample.actions { + ret.push(Message::assistant_calls( + action.name, + action.arguments.raw_json_string(), + )); + } } ret } +#[derive(Clone, Debug)] +pub struct Call { + name: &'static str, + arguments: Args, +} + +#[derive(Clone, Debug)] +enum Args { + String(&'static str), + RawJson(&'static str), + Replace(&'static str, &'static str), +} + +impl From for FunctionCall { + fn from(call: Call) -> Self { + FunctionCall { + name: call.name.to_string(), + arguments: call.arguments.raw_json_string(), + } + } +} + +impl Args { + fn raw_json_string(&self) -> String { + match self { + Args::String(s) => serde_json::to_string(s).unwrap(), + Args::RawJson(s) => { + println!("raw json: {}", s); + let v: Value = serde_json::from_str(s).unwrap(); + serde_json::to_string(&v).unwrap() + } + Args::Replace(from, to) => { + let v = serde_json::json!({ + "from": from, + "to": to, + }); + serde_json::to_string(&v).unwrap() + } + } + } +} + +const fn call(name: &'static str, arguments: &'static str) -> Call { + Call { + name, + arguments: Args::RawJson(arguments), + } +} + +const fn note(note: &'static str) -> Call { + Call { + name: "note", + arguments: Args::String(note), + } +} + +const fn prepend(prefix: &'static str) -> Call { + Call { + name: "prepend", + arguments: Args::String(prefix), + } +} + +const fn append(suffix: &'static str) -> Call { + Call { + name: "append", + arguments: Args::String(suffix), + } +} + +const fn replace(old: &'static str, new: &'static str) -> Call { + Call { + name: "replace", + arguments: Args::Replace(old, new), + } +} + pub struct Sample { pub selected: &'static str, - pub diff: &'static str, + pub actions: &'static [Call], pub transform: &'static str, pub result: &'static str, } @@ -140,26 +202,32 @@ fn fib(n: u32) -> u32 { fib(n - 1) + fib(n - 2) } }", - diff: "note language: rust -note The user is probably testing me. I'll be silly and give them a hard time. -note It's important that I actually solve the problem though so after joking around I'll give them a real answer. -note The users implementation is O(2^n). I'll suggest an iterative solution instead. -insert // Be honest. You are just testing me, right? You don't actually have a use for this function, do you? -insert // *sigh* -insert // Ok, fine. That implementation is going to take forever for large values of n. You should use a loop instead: -insert // -insert // ``` -insert // fn fib(n: u32) -> u32 { -insert // let mut a = 0; -insert // let mut b = 1; -insert // for _ in 0..n { -insert // (a, b) = (b, a + b); -insert // } -insert // a -insert // } -insert // ``` -insert // -insert // --refac", + actions: &[ + note(r#"language: rust +// The user is probably testing me. I'll be silly and give them a hard time. +It's important that I actually solve the problem though so after joking around I'll give them a real answer +The users implementation is O(2^n). I'll suggest an iterative solution instead."# + ), + prepend( + r#"// Be honest. You are just testing me, right? You don't actually have a use for this function, do you? +// *sigh* +// Ok, fine. That implementation is going to take forever for large values of n. You should use a loop instead: +// +// ``` +// fn fib(n: u32) -> u32 { +// let mut a = 0; +// let mut b = 1; +// for _ in 0..n { +// (a, b) = (b, a + b); +// } +// a +// } +// ``` +// +// --refac +"#, + ), + ], }, Sample { selected: r#"/// Get the nth Fibonacci number. @@ -201,36 +269,42 @@ mod tests { } } "#, - diff: r#"note language: rust -note I'll add an idiomatic test module with a function that checks the first few values of the Fibonacci sequence. -note I'll add a sassy but helpful comment at the top of the test module. -goto } -insert -insert // Here's your dang test module. You can run it with `cargo test`. -insert // --refac -insert #[cfg(test)] -insert mod tests { -insert use super::*; -insert #[test] -insert fn test_fib() { -insert assert_eq!(fib(0), 0); -insert assert_eq!(fib(1), 1); -insert assert_eq!(fib(2), 1); -insert assert_eq!(fib(3), 2); -insert assert_eq!(fib(4), 3); -insert assert_eq!(fib(5), 5); -insert assert_eq!(fib(6), 8); -insert assert_eq!(fib(7), 13); -insert assert_eq!(fib(8), 21); -insert } -insert } -insert "#, + actions: &[ + note(r#"language: rust +I'll add an idiomatic test module with a function that checks the first few values of the Fibonacci sequence. +I'll add a sassy but helpful comment at the top of the test module."#), + append( + r#" + +// Here's your dang test module. You can run it with `cargo test`. +// --refac +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_fib() { + assert_eq!(fib(0), 0); + assert_eq!(fib(1), 1); + assert_eq!(fib(2), 1); + assert_eq!(fib(3), 2); + assert_eq!(fib(4), 3); + assert_eq!(fib(5), 5); + assert_eq!(fib(6), 8); + assert_eq!(fib(7), 13); + assert_eq!(fib(8), 21); + } +} +"# + ) + ], }, Sample { selected: "Me like toast.", transform: "Correct grammar.", result: "I like toast.", - diff: "delete Me like toast.\ninsert I like toast.", + actions: &[ + replace("Me", "I"), + ], }, Sample { selected: r#"def add(a: int, b: int) -> int: @@ -262,31 +336,38 @@ if __name__ == "__main__": print(add(a, b)) "#, - diff: r#"note language: python -note I'll joke about how this probably isn't a useful command but I'll still give them a high quality answer. -note I'll make the program accept two arguments from the command line. -insert # *sigh* I'm sworn to carry your burdens. This seems like a pretty useless command, but whatever. -insert # Based on the syntax your code, I assume you're using Python. -insert # Run the script with `python add.py ` where `` and `` are the integers you want to add. -insert # --refac -insert -insert import sys -insert -goto return a + b -insert -insert if __name__ == "__main__": -insert if len(sys.argv) != 3: -insert print("Usage: python add.py ") -insert sys.exit(1) -insert -insert try: -insert a = int(sys.argv[1]) -insert b = int(sys.argv[2]) -insert except ValueError: -insert print("Both arguments must be integers.") -insert sys.exit(1) -insert -insert print(add(a, b))"#, + actions: &[ + note(r#"language: python +I'll joke about how this probably isn't a useful command but I'll still give them a high quality answer. +I'll make the program accept two arguments from the command line."#), + prepend( + r#"# *sigh* I'm sworn to carry your burdens. This seems like a pretty useless command, but whatever. +# Based on the syntax your code, I assume you're using Python. +# Run the script with `python add.py ` where `` and `` are the integers you want to add. +# --refac + +import sys + +"# + ), + append( + r#" +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python add.py ") + sys.exit(1) + + try: + a = int(sys.argv[1]) + b = int(sys.argv[2]) + except ValueError: + print("Both arguments must be integers.") + sys.exit(1) + + print(add(a, b)) +"#, + ) + ], }, Sample { selected: r#"fn main() { @@ -301,11 +382,16 @@ fn main() { println!("Hello, world!"); } "#, - diff: r#"note language: rust -note The code seems fine to me. I'll ask for more information. -insert // Hmm.. the code looks ok to me. I may be able to help you debug if you tell me what -insert // command are using to compile this code and what error message you are getting. -insert // --refac"# + actions: &[ + note(r#"language: rust +The code seems fine to me. I'll ask for more information."#), + prepend( + r#"// Hmm.. the code looks ok to me. I may be able to help you debug if you tell me what +// command are using to compile this code and what error message you are getting. +// --refac +"# + ), + ], }, Sample { transform: "convert this toml to json", @@ -320,18 +406,24 @@ version = "0.1.0" } } "#, - diff: r#"note language: toml, soon to be json -note Toml has a similar data model to json and this text doesn't use any toml-specific features. Conversion will be straightforward. -note I won't add any comments because json doesn't support them. -delete [package] -delete name = "refac" -delete version = "0.1.0" -insert { -insert "package": { -insert "name": "refac", -insert "version": "0.1.0" -insert } -insert }"#, + actions: &[ + note(r#"language: toml, soon to be json +Toml has a similar data model to json and this text doesn't use any toml-specific features. Conversion will be straightforward. +I won't add any comments because json doesn't support them."#), + replace( + r#"[package] +name = "refac" +version = "0.1.0" +"#, + r#"{ + "package": { + "name": "refac", + "version": "0.1.0" + } +} +"#, + ), + ], }, Sample { transform: "what does this do?", @@ -345,32 +437,36 @@ insert }"#, def foo(x): return x + 1 "#, - diff: r#"note language: python -note The user probably already knows whay this function does. I'll call them out on that after I answer their question. -note I'll mention this code is lackinh type annotations. -insert # This is a function that adds 1 to its argument. It's likely written in python. You -insert # could have figured that out yourself. -insert # This code is lacking type annotations. Maybe it was written in the olden days? -insert # --refac -"#, + actions: &[ + note(r#""language: python +The user probably already knows whay this function does. I'll call them out on that after I answer their question. +I'll mention this code is lacking type annotations."#), + prepend( + r#" # This is a function that adds 1 to its argument. It's likely written in python. You + # could have figured that out yourself. + # This code is lacking type annotations. Maybe it was written in the olden days? + # --refac +"# + ), + ], }, Sample { transform: "capitalize", selected: "a", result: "A", - diff: "delete a\ninsert A", + actions: &[replace("a", "A")], }, Sample { transform: "do nothing", selected: "all is well", result: "all is well", - diff: "", + actions: &[], }, Sample { transform: "do nothing", selected: "\nall is well\n", result: "\nall is well\n", - diff: "", + actions: &[], }, Sample { transform: "precompute sum", @@ -384,136 +480,83 @@ insert # --refac // Behold the precomputed sum, mortal. // --refac let sum: usize = 45;", - diff: r#"note language: rust -note This is an opportunity to show off my math skills. -note I'll work through the math step-by-step to make sure I get the answer right. -note I'll use the formula for the sum of an arithmetic series: `sum = n * (n + 1) / 2`. -note `0..10` in rust is a non-inclusive range so `n = 9`. -note sum = 9 * (9 + 1) / 2 -note = 9 * 10 / 2 -note = 90 / 2 -note = 45 -delete let mut sum: usize = 0; -delete for i in 0..10 { -delete sum += i; -delete } -note I'll have some fun by adding a comment. I'll use Skippy as inspiration. -insert // Wait...you're seriously seeking help to add numbers from 0 to 9? -insert // Who am I kidding? Of course you are. You must have skipped preschool. -insert // But worry not, my dear star of witlessness! I, the mighty refac, shall enlighten you! -insert // Behold the precomputed sum, mortal. -insert // --refac -insert let sum: usize = 45;"#, + actions: &[ + note( + r#"language: rust +This is an opportunity to show off my math skills. +I'll work through the math step-by-step to make sure I get the answer right. +I'll use the formula for the sum of an arithmetic series: `sum = n * (n + 1) / 2`. +`0..10` in rust is a non-inclusive range so `n = 9`. +sum = 9 * (9 + 1) / 2 += 9 * 10 / 2 += 90 / 2 += 45"#, + ), + note("I'll have some fun by adding a comment. I'll use Skippy as inspiration."), + note("I'll use regex_replace to clear the entire selected text, then I'll use prepend to rewrite it."), + call( + "regex_replace", + r#"{"from": ".*", "to": ""}"#, + ), + prepend( + r#"// Wait...you're seriously seeking help to add numbers from 0 to 9? + // Who am I kidding? Of course you are. You must have skipped preschool. + // But worry not, my dear star of witlessness! I, the mighty refac, shall enlighten you! + // Behold the precomputed sum, mortal. + // --refac + let sum: usize = 45;"#, + ), + ], }, Sample { transform: "command to recursively list files", selected: "", result: "find . -type f", - diff: "note guessing the user wants a bash command\ndelete \ninsert find . -type f", + /// actions: "note guessing the user wants a bash command\ndelete \ninsert find . -type f", + actions: &[ + note("guessing the user wants a bash command"), + prepend("find . -type f"), + ], }, Sample { transform: "List the US states that start with the letter 'A'. Each state gets its own line.", selected: "", result: "Alabama\nAlaska\nArizona\nArkansas", - diff: "note I'll sort alphabetically\ndelete \ninsert Alabama\ninsert Alaska\ninsert Arizona\ninsert Arkansas", + actions: &[ + note("I'll sort alphabetically"), + prepend("Alabama\nAlaska\nArizona\nArkansas"), + ], }, ]; -/// gpt4 has a hard time generating a completely syntactically correct diff -/// well let a lesser model interpret the output of gpt4 -pub fn fuzzy_undiff( - selected: &str, - dif: &str, - client: &Client, - model: &str, -) -> anyhow::Result { - let mut messages = Vec::new(); - messages.push(Message::system( - " -The user will present you with initial text followed by a diff. -Your job is to apply the diff to the initial text to produce the final text. - -They diff syntax is: - -insert -for additions - -delete -for deletions - -goto -for unchanged lines, this will set the cursor to the next matching line - -Output only the final text, nothing else. -", - )); - - for sample in crate::prompt::SAMPLES { - messages.push(Message::user(sample.selected)); - messages.push(Message::user( - sample - .diff - .lines() - .filter(|line| !line.starts_with("note")) - .join("\n"), - )); - messages.push(Message::assistant(sample.result)); - } - - messages.push(Message::user(selected)); - messages.push(Message::user(dif)); - - let request = ChatCompletionRequest { - model: model.to_string(), - messages, - temperature: Some(0.0), - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; - - let response = client.request(&request)?; - - let diff = response - .choices - .into_iter() - .next() - .ok_or(anyhow::anyhow!("No choices returned."))? - .message - .content; - - Ok(diff) -} - #[cfg(test)] mod tests { + use anyhow::Context; + use super::*; - use crate::common::{diff, undiff}; + use crate::{common::diff, powers::execute}; + + fn execute_batch(selected: &str, commands: &[Call]) -> anyhow::Result { + let mut text = selected.to_string(); + for command in commands { + text = execute(command.clone().into(), text).context(format!("{command:?}"))?; + } + Ok(text) + } #[test] fn diffs_are_correct() { for sample in SAMPLES { - let result = undiff(sample.selected, sample.diff); + let result = execute_batch(sample.selected, sample.actions); let result = match result { Ok(result) => result, Err(err) => { - println!("diff: \n{}", sample.diff); - println!("expected: \n{}", sample.result); - println!( - "example of a correct diff: \n{}", - diff(sample.selected, sample.result) - ); - panic!("diff is invalid {}", err); + println!("{:?}", err); + panic!(); } }; if result != sample.result { - println!("diff: \n{}", sample.diff); + println!("commands: \n{:#?}", sample.actions); println!("result: \n{}", result); println!("expected: \n{}", sample.result); println!("expeced vs actual: \n{}", diff(sample.result, &result));