Skip to content

Commit

Permalink
Update shared secret handling and KAS key initialization
Browse files Browse the repository at this point in the history
The commit updates how shared secrets are handled, making use of asynchronous programming principles for more efficient execution. The usage of static KAS keys has also been replaced with an improved initialization function that ensures the keys are only loaded once. This change leads to significant performance improvements and reduces the possibility of error during key loading.
  • Loading branch information
arkavo-com committed Jun 21, 2024
1 parent 36f0483 commit db48bb8
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 86 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ serde = { version = "1.0.202", features = ["derive"] }
digest = "0.10.7"
data-encoding = "2.6.0"
pem = "3.0.4"
lazy_static = "1.4.0"
rust-crypto = "0.2.36"
hmac = "0.13.0-pre.3"
hkdf = "0.13.0-pre.3"
aes-gcm = "0.10.3"
p256 = "0.13.2"
p256 = "0.13.2"
once_cell = "1.19.0"
166 changes: 82 additions & 84 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::fs;
use std::sync::Arc;
use std::sync::RwLock;

use aes_gcm::{aead::{Aead, KeyInit, OsRng}, Aes256Gcm, Key};
use aes_gcm::aead::generic_array::GenericArray;
use futures_util::{SinkExt, StreamExt};
use lazy_static::lazy_static;
use once_cell::sync::OnceCell;
use p256::elliptic_curve::sec1::ToEncodedPoint;
use p256::SecretKey;
use rand::Rng;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Mutex;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;
use x25519_dalek::{EphemeralSecret, PublicKey};
Expand All @@ -28,16 +25,16 @@ struct PublicKeyMessage {
public_key: Vec<u8>,
}

#[derive(Serialize, Deserialize, Debug)]
#[derive(Debug)]
struct ConnectionState {
shared_secret: Option<Vec<u8>>,
shared_secret: RwLock<Option<Vec<u8>>>,
}

impl ConnectionState {
fn new() -> Self {
println!("New ConnectionState");
ConnectionState {
shared_secret: None,
shared_secret: RwLock::new(None),
}
}
}
Expand All @@ -62,40 +59,17 @@ impl MessageType {
}
}

lazy_static! {
static ref KAS_PUBLIC_KEY_DER: RwLock<Option<Vec<u8>>> = RwLock::new(None);
struct KasKeys {
public_key: Vec<u8>,
private_key: SecretKey,
}

#[tokio::main]
static KAS_KEYS: OnceCell<Arc<KasKeys>> = OnceCell::new();

#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
async fn main() {
// KAS public key
// Load the PEM file
let pem_content = fs::read_to_string("recipient_private_key.pem").unwrap();
// Load the private key from PEM format
let ec_pem_contents = pem_content.as_bytes();
// Parse the pem file
let pem = pem::parse(ec_pem_contents).expect("Failed to parse the PEM.");
// Ensure it's an EC private key
if pem.tag() != "EC PRIVATE KEY" {
println!("Not an EC private key: {:?}", pem.tag());
}
// Parse the private key
let kas_private_key = SecretKey::from_sec1_der(pem.contents());
// Check if successful and continue if Ok
match kas_private_key {
Ok(kas_private_key) => {
// Derive the corresponding public key
let kas_public_key = kas_private_key.public_key();
let kas_public_key_der = kas_public_key.to_encoded_point(true);
let kas_public_key_der_bytes = kas_public_key_der.as_bytes().to_vec();
// Set static KAS_PUBLIC_KEY_DER
{
let mut kas_public_key_der = KAS_PUBLIC_KEY_DER.write().unwrap();
*kas_public_key_der = Some(kas_public_key_der_bytes);
}
}
Err(error) => println!("Problem with the secret key: {:?}", error),
}
init_kas_keys().expect("KAS key not loaded");
// Bind the server to localhost on port 8080
let try_socket = TcpListener::bind("0.0.0.0:8080").await;
let listener = match try_socket {
Expand All @@ -108,12 +82,14 @@ async fn main() {
println!("Listening on: 0.0.0.0:8080");
// Accept connections
while let Ok((stream, _)) = listener.accept().await {
let connection_state = Arc::new(Mutex::new(ConnectionState::new()));
tokio::spawn(handle_connection(stream, connection_state));
let connection_state = Arc::new(ConnectionState::new());
tokio::spawn(async move {
handle_connection(stream, connection_state).await
});
}
}

async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<ConnectionState>>) {
async fn handle_connection(stream: TcpStream, connection_state: Arc<ConnectionState>) {
let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
Expand Down Expand Up @@ -146,7 +122,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<Connec
}

async fn handle_binary_message(
connection_state: &Arc<Mutex<ConnectionState>>,
connection_state: &Arc<ConnectionState>,
data: Vec<u8>,
) -> Option<Message> {
if data.len() < 1 {
Expand All @@ -169,21 +145,14 @@ async fn handle_binary_message(
}

async fn handle_rewrap(
connection_state: &Arc<Mutex<ConnectionState>>,
connection_state: &Arc<ConnectionState>,
payload: &[u8],
) -> Option<Message> {
let session_shared_secret = {
let connection_state = connection_state.lock().await;
// Ensure we have a shared secret
match &connection_state.shared_secret {
Some(secret) => Some(secret.clone()),
None => {
eprintln!("No shared secret available");
None
}
}
let shared_secret = connection_state.shared_secret.read().unwrap();
shared_secret.clone()
};
println!("session shared_secret {:?}", session_shared_secret);
println!("Shared Secret Connection: {}", hex::encode(session_shared_secret.clone().unwrap()));
// Parse NanoTDF header
let mut parser = BinaryParser::new(payload);
let header = match BinaryParser::parse_header(&mut parser) {
Expand All @@ -198,7 +167,6 @@ async fn handle_rewrap(
println!("policy {:?}", policy);
let policy = header.get_policy();
println!("policy binding hex: {}", hex::encode(policy.get_binding().clone().unwrap()));
println!("tdf_ephemeral_key {:?}", header.get_ephemeral_key());
println!("tdf_ephemeral_key hex: {}", hex::encode(header.get_ephemeral_key()));
let tdf_ephemeral_key_bytes = header.get_ephemeral_key();
// Deserialize the public key sent by the client
Expand All @@ -209,7 +177,8 @@ async fn handle_rewrap(
let payload_arr = <[u8; 32]>::try_from(&tdf_ephemeral_key_bytes[1..]).unwrap();
let tdf_ephemeral_public_key = PublicKey::from(payload_arr);
println!("tdf_ephemeral_key {:?}", tdf_ephemeral_public_key);

let kas_private_key = get_kas_private_key().unwrap();
println!("kas_private_key {:?}", kas_private_key);
// TODO Verify the policy binding
// TODO Access check
// Generate Symmetric Key
Expand All @@ -233,7 +202,8 @@ async fn handle_rewrap(
let session_shared_secret = session_shared_secret.unwrap();
let key = Key::<Aes256Gcm>::from_slice(&session_shared_secret);
let cipher = Aes256Gcm::new(&key);
let nonce: [u8; 12] = rand::thread_rng().gen(); // NONCE MUST BE UNIQUE FOR EACH MESSAGE
let mut nonce = [0u8; 12];
OsRng.fill_bytes(&mut nonce);
let nonce = GenericArray::from_slice(&nonce);
let mut wrapped_dek_shared_secret = cipher.encrypt(nonce, dek_shared_secret.as_ref())
.expect("encryption failure!");
Expand All @@ -244,13 +214,14 @@ async fn handle_rewrap(
}

async fn handle_public_key(
connection_state: &Arc<Mutex<ConnectionState>>,
connection_state: &Arc<ConnectionState>,
payload: &[u8],
) -> Option<Message> {
{
let connection_state = connection_state.lock().await;
println!("Connection shared secret: {:?}", connection_state.shared_secret);
if connection_state.shared_secret.is_some() {
let shared_secret_lock = connection_state.shared_secret.read();
let shared_secret = shared_secret_lock.unwrap();
if shared_secret.is_some() {
println!("Shared Secret Connection: {}", hex::encode(shared_secret.clone().unwrap()));
return None;
}
}
Expand All @@ -266,49 +237,76 @@ async fn handle_public_key(
println!("Client Public Key: {:?}", client_public_key);
// Generate an ephemeral private key
let server_private_key = EphemeralSecret::random_from_rng(OsRng);
let mut server_public_key = PublicKey::from(&server_private_key);
let server_public_key = PublicKey::from(&server_private_key);
// Perform the key agreement
let shared_secret = server_private_key.diffie_hellman(&client_public_key);
let shared_secret_bytes = shared_secret.as_bytes();
println!("Shared Secret +++++++++++++");
println!("Shared Secret: {}", hex::encode(shared_secret_bytes));
println!("Shared Secret +++++++++++++");
// Hash the shared secret
let mut hasher = Sha256::new();
hasher.update(shared_secret_bytes);
let hashed_secret = hasher.finalize();
// TODO calculate symmetricKey us hkdf

{
let shared_secret = connection_state.shared_secret.write();
*shared_secret.unwrap() = Some(shared_secret_bytes.to_vec());
}
// Convert server_public_key to bytes
let server_public_key_bytes = server_public_key.to_bytes();
// Determine prefix: 0x02 for even y, 0x03 for odd y
println!("Server Public Key Size: {:?} bytes", server_public_key);
// Send server_public_key as publicKey message
let mut response_data = Vec::new();
// Appending MessageType::PublicKey
response_data.push(MessageType::PublicKey as u8);
// Appending my_public_key bytes
// Appending server_public_key bytes
response_data.extend_from_slice(&server_public_key_bytes);
// Update the connection state with the hashed shared secret
// TODO store symmetric key not shared secret
{
let mut connection_state = connection_state.lock().await;
connection_state.shared_secret = Some(hashed_secret.to_vec());
}
Some(Message::Binary(response_data))
}

async fn handle_kas_public_key(payload: &[u8]) -> Option<Message> {
println!("Received KAS public key: {:?}", payload);
// TODO Use static KAS_PUBLIC_KEY_DER
let kas_public_key_der = KAS_PUBLIC_KEY_DER.read().unwrap();
if let Some(ref kas_public_key_bytes) = *kas_public_key_der {
async fn handle_kas_public_key(_: &[u8]) -> Option<Message> {
println!("Handling KAS public key");
if let Some(kas_public_key_bytes) = get_kas_public_key() {
println!("KAS Public Key Size: {} bytes", kas_public_key_bytes.len());
// TODO make sure compressed key of 33 bytes is sent not 65
let mut response_data = Vec::new();
response_data.push(MessageType::KasPublicKey as u8);
response_data.append(&mut AsRef::<[u8]>::as_ref(kas_public_key_bytes).to_vec());
response_data.extend_from_slice(&kas_public_key_bytes);
return Some(Message::Binary(response_data));
}
return None;
None
// let kas_public_key_der = KAS_PUBLIC_KEY_DER.read().unwrap();
// if let Some(ref kas_public_key_bytes) = *kas_public_key_der {
// println!("KAS Public Key Size: {} bytes", kas_public_key_bytes.len());
// let mut response_data = Vec::new();
// response_data.push(MessageType::KasPublicKey as u8);
// response_data.append(&mut AsRef::<[u8]>::as_ref(kas_public_key_bytes).to_vec());
// return Some(Message::Binary(response_data));
// }
// return None;
}

fn init_kas_keys() -> Result<(), Box<dyn std::error::Error>> {
let pem_content = std::fs::read_to_string("recipient_private_key.pem")?;
let ec_pem_contents = pem_content.as_bytes();
let pem = pem::parse(ec_pem_contents)?;

if pem.tag() != "EC PRIVATE KEY" {
return Err("Not an EC private key".into());
}

let kas_private_key = SecretKey::from_sec1_der(pem.contents())?;
let kas_public_key = kas_private_key.public_key();
let kas_public_key_der = kas_public_key.to_encoded_point(true);
let kas_public_key_der_bytes = kas_public_key_der.as_bytes().to_vec();

let kas_keys = KasKeys {
public_key: kas_public_key_der_bytes,
private_key: kas_private_key,
};

KAS_KEYS.set(Arc::new(kas_keys))
.map_err(|_| "KAS keys already initialized".into())
}

fn get_kas_public_key() -> Option<Vec<u8>> {
KAS_KEYS.get().map(|keys| keys.public_key.clone())
}

fn get_kas_private_key() -> Option<SecretKey> {
KAS_KEYS.get().map(|keys| keys.private_key.clone())
}

0 comments on commit db48bb8

Please sign in to comment.