Skip to content

Commit

Permalink
Add "RewrappedKey" message type and corresponding logic
Browse files Browse the repository at this point in the history
Removed "AeadCore" from the use statement in the main function as it was not used. Added the new "RewrappedKey" to the MessageType and implemented corresponding switches. Reorganized some logic for a clear and easy way of sending and receiving different kinds of messages like PublicKey or KasPublicKey.
  • Loading branch information
arkavo-com committed Jun 14, 2024
1 parent 727fcc7 commit f23996d
Showing 1 changed file with 44 additions and 99 deletions.
143 changes: 44 additions & 99 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fs;
use std::sync::Arc;
use std::sync::RwLock;

use aes_gcm::{aead::{Aead, AeadCore, KeyInit, OsRng}, Aes256Gcm, Key};
use aes_gcm::{aead::{Aead, KeyInit, OsRng}, Aes256Gcm, Key};
use aes_gcm::aead::generic_array::GenericArray;
use data_encoding::HEXUPPER;
use futures_util::{SinkExt, StreamExt};
Expand Down Expand Up @@ -48,6 +48,7 @@ enum MessageType {
PublicKey = 0x01,
KasPublicKey = 0x02,
Rewrap = 0x03,
RewrappedKey = 0x04,
}

impl MessageType {
Expand All @@ -56,6 +57,7 @@ impl MessageType {
0x01 => Some(MessageType::PublicKey),
0x02 => Some(MessageType::KasPublicKey),
0x03 => Some(MessageType::Rewrap),
0x04 => Some(MessageType::RewrappedKey),
_ => None,
}
}
Expand Down Expand Up @@ -148,8 +150,9 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<Connec
if let Some(response) =
handle_binary_message(connection_state.clone(), msg.into_data()).await
{
let response = response.into();
ws_sender.send(response).await.unwrap();
// TODO remove clone
ws_sender.send(response.clone()).await.unwrap();
println!("Message type first byte: {:?}", response.into_data().get(0));
}
}
Err(e) => {
Expand All @@ -175,6 +178,7 @@ async fn handle_binary_message(
Some(MessageType::PublicKey) => handle_public_key(connection_state, payload).await,
Some(MessageType::KasPublicKey) => handle_kas_public_key(payload).await,
Some(MessageType::Rewrap) => handle_rewrap(connection_state, payload).await,
Some(MessageType::RewrappedKey) => None,
None => {
println!("Unknown message type: {:?}", message_type);
None
Expand Down Expand Up @@ -249,9 +253,12 @@ async fn handle_rewrap(
let cipher = Aes256Gcm::new(&key);
let nonce: [u8; 12] = rand::thread_rng().gen(); // NONCE MUST BE UNIQUE FOR EACH MESSAGE
let nonce = GenericArray::from_slice(&nonce);
let wrapped_dek_shared_secret = cipher.encrypt(nonce, dek_shared_secret.as_ref())
let mut wrapped_dek_shared_secret = cipher.encrypt(nonce, dek_shared_secret.as_ref())
.expect("encryption failure!");
Some(Message::Binary(Vec::from(wrapped_dek_shared_secret)))
let mut response_data = Vec::new();
response_data.push(MessageType::RewrappedKey as u8);
response_data.append(&mut wrapped_dek_shared_secret);
return Some(Message::Binary(response_data));
}

async fn handle_public_key(
Expand All @@ -260,9 +267,7 @@ async fn handle_public_key(
) -> Option<Message> {
// Generate an ephemeral private key
let my_private_key = EphemeralSecret::random_from_rng(OsRng);
let my_public_key = PublicKey::from(&my_private_key);
println!("Server Public Key: {}", hex::encode(my_public_key.as_ref()));

let server_public_key = PublicKey::from(&my_private_key);
let payload_arr: [u8; 32];
// Deserialize the public key sent by the client
if payload.len() == 33 {
Expand All @@ -279,108 +284,48 @@ async fn handle_public_key(
// Perform the key agreement
let shared_secret = my_private_key.diffie_hellman(&peer_public_key);
let shared_secret_bytes = shared_secret.as_bytes();

println!("Shared Secret: {}", hex::encode(shared_secret_bytes));
// Hash the shared secret
let mut hasher = Sha256::new();
hasher.update(shared_secret_bytes);
let hashed_secret = hasher.finalize();
println!("Shared Secret: {}", hex::encode(hashed_secret.to_vec()));
// TODO calculate symmetricKey us hkdf

// Convert server_public_key to bytes
let server_public_key_bytes = server_public_key.to_bytes();
// Determine prefix: 0x02 for even y, 0x03 for odd y
let prefix = if server_public_key_bytes[31] % 2 == 0 { 0x02 } else { 0x03 };
// Prepare a vector to hold the compressed key (1 byte prefix + 32 bytes key)
let mut server_compressed_public_key = Vec::with_capacity(33);
server_compressed_public_key.push(prefix);
server_compressed_public_key.extend_from_slice(&server_public_key_bytes);
println!("Compressed Server PublicKey: {:}", hex::encode(&server_compressed_public_key));
println!("Server Public Key Size: {} bytes", &server_compressed_public_key.len());
// Send server_public_key as publicKey message
let mut response_data = Vec::new();
// Appending MessageType::PublicKey
response_data.push(MessageType::PublicKey as u8);
// Appending my_public_key bytes
response_data.append(&mut server_compressed_public_key);

// Update the connection state with the hashed shared secret
// TODO store symmetric key not shared secret
let mut connection_state = connection_state.lock().await;
connection_state.shared_secret = Some(hashed_secret.to_vec());

Some(Message::Binary(my_public_key.as_ref().to_vec()))
Some(Message::Binary(response_data))
}

async fn handle_kas_public_key(payload: &[u8]) -> Option<Message> {
println!("Received KAS public key: {:?}", payload);
// Use static KAS_PUBLIC_KEY_DER
// TODO Use static KAS_PUBLIC_KEY_DER
let kas_public_key_der = KAS_PUBLIC_KEY_DER.read().unwrap();
if let Some(ref public_key) = *kas_public_key_der {
return Some(Message::Binary(public_key.clone()));
if let Some(ref kas_public_key_bytes) = *kas_public_key_der {
println!("KAS Public Key Size: {} bytes", kas_public_key_bytes.len());
// TODO make sure compressed key of 33 bytes is sent not 65
let mut response_data = Vec::new();
response_data.push(MessageType::KasPublicKey as u8);
response_data.append(&mut AsRef::<[u8]>::as_ref(kas_public_key_bytes).to_vec());
return Some(Message::Binary(response_data));
}
return None;
}

// async fn get_compressed_public_key() -> Result<Vec<u8>, Box<dyn std::error::Error>> {
// println!("get_compressed_public_key");
// // Load the private key from DER format
// // let mut file = File::open("recipient_private_key.der").await?;
// // let mut private_key_der = vec![];
// // file.read_to_end(&mut private_key_der).await?;
// // let ec_key = EcKey::private_key_from_der(&private_key_der)?;
// // Load the EC private key from the PEM file
// let ec_key = load_ec_private_key("recipient_private_key.pem").await?;
// // Generate EC private key
// // let ec_key = generate_ecdh_key()?;
//
// let curve_name = ec_key.group().curve_name();
// match curve_name {
// Some(nid) => {
// let name = nid.long_name();
// match name {
// Ok(value) => println!("Curve Name: {}", value),
// Err(_) => println!("Curve Name: failed to get"),
// }
// },
// None => {
// println!("failed to get curve_name");
// }
// }
//
// // Extract private key
// let private_key_bn_result = ec_key.private_key().to_owned();
// let private_key_bn = match private_key_bn_result {
// Ok(bn) => bn,
// Err(e) => {
// println!("Failed to extract private key: {}", e);
// return Err(Box::new(e));
// },
// };
// let private_key_hex = hex::encode(&private_key_bn.to_vec());
//
// // Extract public key
// let ec_group = ec_key.group();
// let public_key = ec_key.public_key();
// let mut ctx = openssl::bn::BigNumContext::new()?;
// let public_key_bytes = public_key.to_bytes(&ec_group, PointConversionForm::UNCOMPRESSED, &mut ctx)?;
//
// // Convert public key to hex
// let public_key_hex = hex::encode(public_key_bytes);
//
// // Log the information
// println!("EC Private Key: {}", private_key_hex);
// println!("EC Public Key: {}", public_key_hex);
//
// // Get the compressed public key
// let mut bn_ctx = openssl::bn::BigNumContext::new()?;
// let compressed_pub_key = public_key.to_bytes(
// &ec_key.group(),
// PointConversionForm::COMPRESSED,
// &mut bn_ctx,
// )?;
//
// Ok(compressed_pub_key)
// }
// fn generate_ecdh_key() -> Result<EcKey<Private>, Box<dyn Error>> {
// // Create the EC group for P-256
// let ec_group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
//
// // Generate a new private key for the group
// let ec_key = EcKey::generate(&ec_group)?;
//
// Ok(ec_key)
// }

// async fn load_ec_private_key(filename: &str) -> Result<EcKey<Private>, Box<dyn Error>> {
// // Read the PEM file content
// let mut file = File::open(filename).await?;
// let mut pem = String::new();
// file.read_to_string(&mut pem).await?;
//
// // Load the private key from PEM format
// let pkey = PKey::private_key_from_pem(pem.as_bytes())?;
// let ec_key = pkey.ec_key()?;
//
// Ok(ec_key)
// }
}

0 comments on commit f23996d

Please sign in to comment.