Skip to content

Commit

Permalink
feat: redactions are reverted before returning the output
Browse files Browse the repository at this point in the history
  • Loading branch information
kyluke mcdougall committed Dec 2, 2024
1 parent b8d9ffd commit 569f490
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 13 deletions.
12 changes: 11 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ tempfile = "3.13.0"
dirs = "5.0.1"
regex = "1.11.0"

[dependencies.uuid]
version = "1.11.0"
features = ["v4"]
23 changes: 11 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod openai;
mod output;
mod path;
mod repository;
mod redactions;

use crate::args::Args;
use crate::config::repository::ConfigRepository;
Expand All @@ -15,7 +16,7 @@ use clap::Parser;
use config::{model::keys::ConfigKeys, service::config_service};
use openai::service::chat::chat;
use output::outputter;
use regex::Regex;
use output::message::Message;
use repository::db::SqliteRepository;
use std::fs::create_dir_all;
use std::io::IsTerminal;
Expand Down Expand Up @@ -92,16 +93,9 @@ async fn request_response_from_ai<R: ConfigRepository>(
None => input.clone(),
};

let redactions = redacted_config::fetch_redactions(repo);
let input_with_redactions =
redactions
.iter()
.fold(input_with_local_context.clone(), |acc, redaction| {
let re = Regex::new(&format!("(?i){}", regex::escape(redaction))).unwrap();
re.replace_all(&acc, "<redacted>").to_string()
});
let (redacted_input, mapped_redactions) = redactions::redact::redact(repo, &input_with_local_context);

let chat_response = match chat(&open_ai_api_key.value, user_defined_system_prompt, &input_with_redactions).await {
let chat_response = match chat(&open_ai_api_key.value, user_defined_system_prompt, &redacted_input).await {
Ok(response) => response,
Err(err) => {
println!("{:#?}", err);
Expand All @@ -112,9 +106,14 @@ async fn request_response_from_ai<R: ConfigRepository>(
let output_messages = chat_response
.iter()
.map(|message| message.to_output_message())
.collect();
.collect::<Vec<Message>>();

outputter::print(output_messages);
let unredacted_messages = output_messages.iter().map(|message| {
let unredacted_message = redactions::revert::unredact(&mapped_redactions, &message.message);
message.copy_with_message(unredacted_message)
}).collect::<Vec<Message>>();

outputter::print(unredacted_messages);
Ok(())
}

Expand Down
1 change: 1 addition & 0 deletions src/openai/model/role.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[derive(Clone)]
pub enum Role {
System,
User,
Expand Down
9 changes: 9 additions & 0 deletions src/output/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,12 @@ pub struct Message {
pub role: Role,
pub message: String,
}

impl Message {
pub fn copy_with_message(&self, message: String) -> Self {
return Message {
role: self.role.clone(),
message
}
}
}
14 changes: 14 additions & 0 deletions src/redactions/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use std::collections::HashMap;
use uuid::Uuid;

pub fn redaction_map(redactions: Vec<String>) -> HashMap<String, String> {
let mut map = HashMap::new();
for redaction in redactions {
map.insert(redaction, generate_uuid_v4());
}
map
}

fn generate_uuid_v4() -> String {
Uuid::new_v4().to_string()
}
3 changes: 3 additions & 0 deletions src/redactions/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pub(crate) mod redact;
pub(crate) mod revert;
pub(crate) mod common;
21 changes: 21 additions & 0 deletions src/redactions/redact.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use std::collections::HashMap;
use regex::Regex;
use crate::config::repository::ConfigRepository;
use crate::config::service::redacted_config;

use super::common;

pub fn redact<R: ConfigRepository>(repo: &R, content: &str) -> (String, HashMap<String, String>) {
let redactions = redacted_config::fetch_redactions(repo);
let mapped_redactions = common::redaction_map(redactions);

let input_with_redactions =
mapped_redactions
.iter()
.fold(content.to_string(), |acc, (redaction, id)| {
let re = Regex::new(&format!("(?i){}", regex::escape(redaction))).unwrap();
re.replace_all(&acc, id).to_string()
});

(input_with_redactions.to_string(), mapped_redactions)
}
12 changes: 12 additions & 0 deletions src/redactions/revert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use std::collections::HashMap;
use regex::Regex;

pub fn unredact(mapped_redactions: &HashMap<String, String>, content: &str) -> String {
mapped_redactions
.iter()
.fold(content.to_string(), |acc, (redaction, id)| {
let re = Regex::new(&format!("(?i){}", regex::escape(id))).unwrap();
re.replace_all(&acc, redaction).to_string()
})
}

0 comments on commit 569f490

Please sign in to comment.