From 8fde10b489fc759cbf198602f90d6d3751090bad Mon Sep 17 00:00:00 2001 From: Vadim Anufriev Date: Tue, 4 Jun 2024 16:32:25 +0400 Subject: [PATCH] change function signatures --- core/lib/Cargo.toml | 1 - core/lib/src/config/secret_key.rs | 28 +++++++++------------ examples/private-data/Cargo.toml | 1 + examples/private-data/src/main.rs | 40 ++++++++++++++++++++++-------- examples/private-data/src/tests.rs | 13 +++++++++- 5 files changed, 55 insertions(+), 28 deletions(-) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index ea25fe346a..ded8df6ae3 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -74,7 +74,6 @@ cookie = { version = "0.18", features = ["percent-encode"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" aes-gcm = "0.10.3" -base64 = "0.22.1" [dependencies.hyper-util] version = "0.1.3" diff --git a/core/lib/src/config/secret_key.rs b/core/lib/src/config/secret_key.rs index a8c48c76b7..4489c93a0f 100644 --- a/core/lib/src/config/secret_key.rs +++ b/core/lib/src/config/secret_key.rs @@ -2,10 +2,7 @@ use std::fmt; use aes_gcm::{Aes256Gcm, Nonce}; use aes_gcm::aead::{generic_array::GenericArray, Aead, KeyInit}; -use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use rand::RngCore; - - use cookie::Key; use serde::{de, ser, Deserialize, Serialize}; @@ -194,7 +191,9 @@ impl SecretKey { /// /// # Example /// ```rust - /// let plaintext = "I like turtles"; + /// use rocket::config::SecretKey; + /// + /// let plaintext = "I like turtles".as_bytes(); /// let secret_key = SecretKey::generate().expect("error generate key"); /// /// let encrypted = secret_key.encrypt(&plaintext).expect("can't encrypt"); @@ -202,7 +201,7 @@ impl SecretKey { /// /// assert_eq!(decrypted, plaintext); /// ``` - pub fn encrypt(&self, plaintext: &str) -> Result { + pub fn encrypt>(&self, value: T) -> Result, &'static str> { // Convert the encryption key to a fixed-length array let key: [u8; KEY_LEN] = self.key.encryption().try_into().map_err(|_| "enc key len error")?; @@ -216,31 +215,29 @@ impl SecretKey { let nonce = Nonce::from_slice(&nonce); // Encrypt the plaintext using the nonce - let ciphertext = aead.encrypt(nonce, plaintext.as_ref()).map_err(|_| "encryption error")?; + let ciphertext = aead.encrypt(nonce, value.as_ref()).map_err(|_| "encryption error")?; // Prepare a vector to hold the nonce and ciphertext let mut encrypted_data = Vec::with_capacity(NONCE_LEN + ciphertext.len()); encrypted_data.extend_from_slice(nonce); encrypted_data.extend_from_slice(&ciphertext); - // Return the base64-encoded result - Ok(URL_SAFE.encode(encrypted_data)) + Ok(encrypted_data) } /// Decrypts the given base64-encoded encrypted data. /// Extracts the nonce from the data and uses it for decryption. /// Returns the decrypted plaintext string. - pub fn decrypt(&self, encrypted: &str) -> Result { - // Decode the base64-encoded encrypted data - let decoded = URL_SAFE.decode(encrypted).map_err(|_| "bad base64 value")?; + pub fn decrypt>(&self, encrypted: T) -> Result, &'static str> { + let encrypted = encrypted.as_ref(); // Check if the length of decoded data is at least the length of the nonce - if decoded.len() < NONCE_LEN { - return Err("length of decoded data is < NONCE_LEN"); + if encrypted.len() <= NONCE_LEN { + return Err("length of encrypted data is <= NONCE_LEN"); } // Split the decoded data into nonce and ciphertext - let (nonce, ciphertext) = decoded.split_at(NONCE_LEN); + let (nonce, ciphertext) = encrypted.split_at(NONCE_LEN); let nonce = Nonce::from_slice(nonce); // Convert the encryption key to a fixed-length array @@ -253,8 +250,7 @@ impl SecretKey { let decrypted = aead.decrypt(nonce, ciphertext) .map_err(|_| "invalid key/nonce/value: bad seal")?; - // Convert the decrypted bytes to a UTF-8 string - String::from_utf8(decrypted).map_err(|_| "bad unsealed utf8") + Ok(decrypted) } } diff --git a/examples/private-data/Cargo.toml b/examples/private-data/Cargo.toml index 83624ad16d..15eea69093 100644 --- a/examples/private-data/Cargo.toml +++ b/examples/private-data/Cargo.toml @@ -7,3 +7,4 @@ publish = false [dependencies] rocket = { path = "../../core/lib", features = ["secrets"] } +base64 = "0.22.1" diff --git a/examples/private-data/src/main.rs b/examples/private-data/src/main.rs index cd77b0c571..e9bb9c7545 100644 --- a/examples/private-data/src/main.rs +++ b/examples/private-data/src/main.rs @@ -3,29 +3,49 @@ extern crate rocket; use rocket::{Config, State}; use rocket::fairing::AdHoc; +use rocket::response::status; +use rocket::http::Status; +use base64::{engine::general_purpose::URL_SAFE, Engine as _}; #[cfg(test)] mod tests; #[get("/encrypt/")] -fn encrypt_endpoint(msg: &str, config: &State) -> String{ +fn encrypt_endpoint(msg: &str, config: &State) -> Result> { let secret_key = config.secret_key.clone(); - let encrypted = secret_key.encrypt(msg).unwrap(); + + let encrypted = secret_key.encrypt(msg).map_err(|_| { + status::Custom(Status::InternalServerError, "Failed to encrypt message".to_string()) + })?; + + let encrypted_msg = URL_SAFE.encode(&encrypted); info!("received message for encrypt: '{}'", msg); - info!("encrypted msg: '{}'", encrypted); + info!("encrypted msg: '{}'", encrypted_msg); - encrypted + Ok(encrypted_msg) } #[get("/decrypt/")] - fn decrypt_endpoint(msg: &str, config: &State) -> String { - let secret_key = config.secret_key.clone(); - let decrypted = secret_key.decrypt(msg).unwrap(); +fn decrypt_endpoint(msg: &str, config: &State) -> Result> { + let secret_key = config.secret_key.clone(); + + let decoded = URL_SAFE.decode(msg).map_err(|_| { + status::Custom(Status::BadRequest, "Failed to decode base64".to_string()) + })?; + + let decrypted = secret_key.decrypt(&decoded).map_err(|_| { + status::Custom(Status::InternalServerError, "Failed to decrypt message".to_string()) + })?; + + let decrypted_msg = String::from_utf8(decrypted).map_err(|_| { + status::Custom(Status::InternalServerError, + "Failed to convert decrypted message to UTF-8".to_string()) + })?; - info!("received message for decrypt: '{}'", msg); - info!("decrypted msg: '{}'", decrypted); + info!("received message for decrypt: '{}'", msg); + info!("decrypted msg: '{}'", decrypted_msg); - decrypted + Ok(decrypted_msg) } #[launch] diff --git a/examples/private-data/src/tests.rs b/examples/private-data/src/tests.rs index c8f000777e..041309022c 100644 --- a/examples/private-data/src/tests.rs +++ b/examples/private-data/src/tests.rs @@ -1,7 +1,18 @@ -use rocket::local::blocking::Client; +use rocket::{config::SecretKey, local::blocking::Client}; #[test] fn encrypt_decrypt() { + let secret_key = SecretKey::generate().unwrap(); + let msg = "very-secret-message".as_bytes(); + + let encrypted = secret_key.encrypt(msg).unwrap(); + let decrypted = secret_key.decrypt(&encrypted).unwrap(); + + assert_eq!(msg, decrypted); +} + +#[test] +fn encrypt_decrypt_api() { let client = Client::tracked(super::rocket()).unwrap(); let msg = "some-secret-message";