From db48bb87d38cc6f5fbe90c7ea987147d47add873 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Fri, 21 Jun 2024 00:03:05 -0400 Subject: [PATCH] Update shared secret handling and KAS key initialization The commit updates how shared secrets are handled, making use of asynchronous programming principles for more efficient execution. The usage of static KAS keys has also been replaced with an improved initialization function that ensures the keys are only loaded once. This change leads to significant performance improvements and reduces the possibility of error during key loading. --- Cargo.toml | 4 +- src/main.rs | 166 ++++++++++++++++++++++++++-------------------------- 2 files changed, 84 insertions(+), 86 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4bc52c7..d3fbcba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,9 +18,9 @@ serde = { version = "1.0.202", features = ["derive"] } digest = "0.10.7" data-encoding = "2.6.0" pem = "3.0.4" -lazy_static = "1.4.0" rust-crypto = "0.2.36" hmac = "0.13.0-pre.3" hkdf = "0.13.0-pre.3" aes-gcm = "0.10.3" -p256 = "0.13.2" \ No newline at end of file +p256 = "0.13.2" +once_cell = "1.19.0" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e1b5aee..0f3f87e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,15 @@ -use std::fs; use std::sync::Arc; use std::sync::RwLock; use aes_gcm::{aead::{Aead, KeyInit, OsRng}, Aes256Gcm, Key}; use aes_gcm::aead::generic_array::GenericArray; use futures_util::{SinkExt, StreamExt}; -use lazy_static::lazy_static; +use once_cell::sync::OnceCell; use p256::elliptic_curve::sec1::ToEncodedPoint; use p256::SecretKey; -use rand::Rng; +use rand::RngCore; use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Mutex; use tokio_tungstenite::accept_async; use tokio_tungstenite::tungstenite::Message; use x25519_dalek::{EphemeralSecret, PublicKey}; @@ -28,16 +25,16 @@ struct PublicKeyMessage { public_key: Vec, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug)] struct ConnectionState { - shared_secret: Option>, + shared_secret: RwLock>>, } impl ConnectionState { fn new() -> Self { println!("New ConnectionState"); ConnectionState { - shared_secret: None, + shared_secret: RwLock::new(None), } } } @@ -62,40 +59,17 @@ impl MessageType { } } -lazy_static! { - static ref KAS_PUBLIC_KEY_DER: RwLock>> = RwLock::new(None); +struct KasKeys { + public_key: Vec, + private_key: SecretKey, } -#[tokio::main] +static KAS_KEYS: OnceCell> = OnceCell::new(); + +#[tokio::main(flavor = "multi_thread", worker_threads = 4)] async fn main() { // KAS public key - // Load the PEM file - let pem_content = fs::read_to_string("recipient_private_key.pem").unwrap(); - // Load the private key from PEM format - let ec_pem_contents = pem_content.as_bytes(); - // Parse the pem file - let pem = pem::parse(ec_pem_contents).expect("Failed to parse the PEM."); - // Ensure it's an EC private key - if pem.tag() != "EC PRIVATE KEY" { - println!("Not an EC private key: {:?}", pem.tag()); - } - // Parse the private key - let kas_private_key = SecretKey::from_sec1_der(pem.contents()); - // Check if successful and continue if Ok - match kas_private_key { - Ok(kas_private_key) => { - // Derive the corresponding public key - let kas_public_key = kas_private_key.public_key(); - let kas_public_key_der = kas_public_key.to_encoded_point(true); - let kas_public_key_der_bytes = kas_public_key_der.as_bytes().to_vec(); - // Set static KAS_PUBLIC_KEY_DER - { - let mut kas_public_key_der = KAS_PUBLIC_KEY_DER.write().unwrap(); - *kas_public_key_der = Some(kas_public_key_der_bytes); - } - } - Err(error) => println!("Problem with the secret key: {:?}", error), - } + init_kas_keys().expect("KAS key not loaded"); // Bind the server to localhost on port 8080 let try_socket = TcpListener::bind("0.0.0.0:8080").await; let listener = match try_socket { @@ -108,12 +82,14 @@ async fn main() { println!("Listening on: 0.0.0.0:8080"); // Accept connections while let Ok((stream, _)) = listener.accept().await { - let connection_state = Arc::new(Mutex::new(ConnectionState::new())); - tokio::spawn(handle_connection(stream, connection_state)); + let connection_state = Arc::new(ConnectionState::new()); + tokio::spawn(async move { + handle_connection(stream, connection_state).await + }); } } -async fn handle_connection(stream: TcpStream, connection_state: Arc>) { +async fn handle_connection(stream: TcpStream, connection_state: Arc) { let ws_stream = match accept_async(stream).await { Ok(ws) => ws, Err(e) => { @@ -146,7 +122,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc>, + connection_state: &Arc, data: Vec, ) -> Option { if data.len() < 1 { @@ -169,21 +145,14 @@ async fn handle_binary_message( } async fn handle_rewrap( - connection_state: &Arc>, + connection_state: &Arc, payload: &[u8], ) -> Option { let session_shared_secret = { - let connection_state = connection_state.lock().await; - // Ensure we have a shared secret - match &connection_state.shared_secret { - Some(secret) => Some(secret.clone()), - None => { - eprintln!("No shared secret available"); - None - } - } + let shared_secret = connection_state.shared_secret.read().unwrap(); + shared_secret.clone() }; - println!("session shared_secret {:?}", session_shared_secret); + println!("Shared Secret Connection: {}", hex::encode(session_shared_secret.clone().unwrap())); // Parse NanoTDF header let mut parser = BinaryParser::new(payload); let header = match BinaryParser::parse_header(&mut parser) { @@ -198,7 +167,6 @@ async fn handle_rewrap( println!("policy {:?}", policy); let policy = header.get_policy(); println!("policy binding hex: {}", hex::encode(policy.get_binding().clone().unwrap())); - println!("tdf_ephemeral_key {:?}", header.get_ephemeral_key()); println!("tdf_ephemeral_key hex: {}", hex::encode(header.get_ephemeral_key())); let tdf_ephemeral_key_bytes = header.get_ephemeral_key(); // Deserialize the public key sent by the client @@ -209,7 +177,8 @@ async fn handle_rewrap( let payload_arr = <[u8; 32]>::try_from(&tdf_ephemeral_key_bytes[1..]).unwrap(); let tdf_ephemeral_public_key = PublicKey::from(payload_arr); println!("tdf_ephemeral_key {:?}", tdf_ephemeral_public_key); - + let kas_private_key = get_kas_private_key().unwrap(); + println!("kas_private_key {:?}", kas_private_key); // TODO Verify the policy binding // TODO Access check // Generate Symmetric Key @@ -233,7 +202,8 @@ async fn handle_rewrap( let session_shared_secret = session_shared_secret.unwrap(); let key = Key::::from_slice(&session_shared_secret); let cipher = Aes256Gcm::new(&key); - let nonce: [u8; 12] = rand::thread_rng().gen(); // NONCE MUST BE UNIQUE FOR EACH MESSAGE + let mut nonce = [0u8; 12]; + OsRng.fill_bytes(&mut nonce); let nonce = GenericArray::from_slice(&nonce); let mut wrapped_dek_shared_secret = cipher.encrypt(nonce, dek_shared_secret.as_ref()) .expect("encryption failure!"); @@ -244,13 +214,14 @@ async fn handle_rewrap( } async fn handle_public_key( - connection_state: &Arc>, + connection_state: &Arc, payload: &[u8], ) -> Option { { - let connection_state = connection_state.lock().await; - println!("Connection shared secret: {:?}", connection_state.shared_secret); - if connection_state.shared_secret.is_some() { + let shared_secret_lock = connection_state.shared_secret.read(); + let shared_secret = shared_secret_lock.unwrap(); + if shared_secret.is_some() { + println!("Shared Secret Connection: {}", hex::encode(shared_secret.clone().unwrap())); return None; } } @@ -266,49 +237,76 @@ async fn handle_public_key( println!("Client Public Key: {:?}", client_public_key); // Generate an ephemeral private key let server_private_key = EphemeralSecret::random_from_rng(OsRng); - let mut server_public_key = PublicKey::from(&server_private_key); + let server_public_key = PublicKey::from(&server_private_key); // Perform the key agreement let shared_secret = server_private_key.diffie_hellman(&client_public_key); let shared_secret_bytes = shared_secret.as_bytes(); println!("Shared Secret +++++++++++++"); println!("Shared Secret: {}", hex::encode(shared_secret_bytes)); println!("Shared Secret +++++++++++++"); - // Hash the shared secret - let mut hasher = Sha256::new(); - hasher.update(shared_secret_bytes); - let hashed_secret = hasher.finalize(); - // TODO calculate symmetricKey us hkdf - + { + let shared_secret = connection_state.shared_secret.write(); + *shared_secret.unwrap() = Some(shared_secret_bytes.to_vec()); + } // Convert server_public_key to bytes let server_public_key_bytes = server_public_key.to_bytes(); - // Determine prefix: 0x02 for even y, 0x03 for odd y - println!("Server Public Key Size: {:?} bytes", server_public_key); // Send server_public_key as publicKey message let mut response_data = Vec::new(); // Appending MessageType::PublicKey response_data.push(MessageType::PublicKey as u8); - // Appending my_public_key bytes + // Appending server_public_key bytes response_data.extend_from_slice(&server_public_key_bytes); - // Update the connection state with the hashed shared secret - // TODO store symmetric key not shared secret - { - let mut connection_state = connection_state.lock().await; - connection_state.shared_secret = Some(hashed_secret.to_vec()); - } Some(Message::Binary(response_data)) } -async fn handle_kas_public_key(payload: &[u8]) -> Option { - println!("Received KAS public key: {:?}", payload); - // TODO Use static KAS_PUBLIC_KEY_DER - let kas_public_key_der = KAS_PUBLIC_KEY_DER.read().unwrap(); - if let Some(ref kas_public_key_bytes) = *kas_public_key_der { +async fn handle_kas_public_key(_: &[u8]) -> Option { + println!("Handling KAS public key"); + if let Some(kas_public_key_bytes) = get_kas_public_key() { println!("KAS Public Key Size: {} bytes", kas_public_key_bytes.len()); - // TODO make sure compressed key of 33 bytes is sent not 65 let mut response_data = Vec::new(); response_data.push(MessageType::KasPublicKey as u8); - response_data.append(&mut AsRef::<[u8]>::as_ref(kas_public_key_bytes).to_vec()); + response_data.extend_from_slice(&kas_public_key_bytes); return Some(Message::Binary(response_data)); } - return None; + None + // let kas_public_key_der = KAS_PUBLIC_KEY_DER.read().unwrap(); + // if let Some(ref kas_public_key_bytes) = *kas_public_key_der { + // println!("KAS Public Key Size: {} bytes", kas_public_key_bytes.len()); + // let mut response_data = Vec::new(); + // response_data.push(MessageType::KasPublicKey as u8); + // response_data.append(&mut AsRef::<[u8]>::as_ref(kas_public_key_bytes).to_vec()); + // return Some(Message::Binary(response_data)); + // } + // return None; +} + +fn init_kas_keys() -> Result<(), Box> { + let pem_content = std::fs::read_to_string("recipient_private_key.pem")?; + let ec_pem_contents = pem_content.as_bytes(); + let pem = pem::parse(ec_pem_contents)?; + + if pem.tag() != "EC PRIVATE KEY" { + return Err("Not an EC private key".into()); + } + + let kas_private_key = SecretKey::from_sec1_der(pem.contents())?; + let kas_public_key = kas_private_key.public_key(); + let kas_public_key_der = kas_public_key.to_encoded_point(true); + let kas_public_key_der_bytes = kas_public_key_der.as_bytes().to_vec(); + + let kas_keys = KasKeys { + public_key: kas_public_key_der_bytes, + private_key: kas_private_key, + }; + + KAS_KEYS.set(Arc::new(kas_keys)) + .map_err(|_| "KAS keys already initialized".into()) +} + +fn get_kas_public_key() -> Option> { + KAS_KEYS.get().map(|keys| keys.public_key.clone()) +} + +fn get_kas_private_key() -> Option { + KAS_KEYS.get().map(|keys| keys.private_key.clone()) } \ No newline at end of file