Skip to content

Commit

Permalink
Refactor connection state handling and key generation
Browse files Browse the repository at this point in the history
This update refactors the manner in which the connection state is handled. Instead of cloning the connection state, we now use a mutable reference. Additionally, changes were made in the process of generating the server's ephemeral private key, handling the client's public key, performing key agreement, and sending the server's public key as a publicKey message. Extraneous unused code blocks and fixme comments have been removed for a cleaner and more maintainable codebase.
  • Loading branch information
arkavo-com committed Jun 20, 2024
1 parent af36faf commit 36f0483
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct ConnectionState {

impl ConnectionState {
fn new() -> Self {
println!("New ConnectionState");
ConnectionState {
shared_secret: None,
}
Expand Down Expand Up @@ -65,14 +66,6 @@ lazy_static! {
static ref KAS_PUBLIC_KEY_DER: RwLock<Option<Vec<u8>>> = RwLock::new(None);
}

const ENCRYPTED_PAYLOAD: &str = "\
4c 31 4c 01 0e 6b 61 73 2e 76 69 72 74 72 75 2e 63 6f 6d 80\
80 00 01 15 6b 61 73 2e 76 69 72 74 72 75 2e 63 6f 6d 2f 70\
6f 6c 69 63 79 b5 e4 13 a6 02 11 e5 f1 7b 22 34 a0 cd 3f 36\
ff 7b ba 6d 8f e8 df 23 f6 2c 9d 09 35 6f 85 82 f8 a9 cf 15\
12 6c 8a 9d a4 6c 5e 4e 0c bc c8 26 97 19 ac 05 1b 80 62 5c\
c7 54 03 03 6f fb 82 87 1f 02 f7 7f ba e5 26 09 da";

#[tokio::main]
async fn main() {
// KAS public key
Expand Down Expand Up @@ -121,10 +114,6 @@ async fn main() {
}

async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<ConnectionState>>) {
// FIXME read from rewrap
let ec_bytes: Vec<u8> = hex::decode(ENCRYPTED_PAYLOAD.replace(" ", "")).unwrap();
let _nanotdf = BinaryParser::new(&*ec_bytes);

let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
Expand All @@ -133,9 +122,6 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<Connec
}
};
let (mut ws_sender, mut ws_receiver) = ws_stream.split();
// TODO rewrap
// let compressed_pub_key = get_compressed_public_key().await.expect("Failed to get compressed public key");

// Handle incoming WebSocket messages
while let Some(message) = ws_receiver.next().await {
match message {
Expand All @@ -145,12 +131,10 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<Connec
println!("Received a close message.");
return;
}
if let Some(response) =
handle_binary_message(connection_state.clone(), msg.into_data()).await
if let Some(response) = handle_binary_message(&connection_state, msg.into_data()).await
{
// TODO remove clone
ws_sender.send(response.clone()).await.unwrap();
println!("Message type first byte: {:?}", response.into_data().get(0));
ws_sender.send(response.clone()).await.expect("ws send failed");
}
}
Err(e) => {
Expand All @@ -162,7 +146,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<Mutex<Connec
}

async fn handle_binary_message(
connection_state: Arc<Mutex<ConnectionState>>,
connection_state: &Arc<Mutex<ConnectionState>>,
data: Vec<u8>,
) -> Option<Message> {
if data.len() < 1 {
Expand All @@ -185,7 +169,7 @@ async fn handle_binary_message(
}

async fn handle_rewrap(
connection_state: Arc<Mutex<ConnectionState>>,
connection_state: &Arc<Mutex<ConnectionState>>,
payload: &[u8],
) -> Option<Message> {
let session_shared_secret = {
Expand Down Expand Up @@ -260,29 +244,35 @@ async fn handle_rewrap(
}

async fn handle_public_key(
connection_state: Arc<Mutex<ConnectionState>>,
connection_state: &Arc<Mutex<ConnectionState>>,
payload: &[u8],
) -> Option<Message> {
// Generate an ephemeral private key
let my_private_key = EphemeralSecret::random_from_rng(OsRng);
let server_public_key = PublicKey::from(&my_private_key);
let payload_arr: [u8; 32];
// Deserialize the public key sent by the client
if payload.len() == 33 {
// If payload length is 33, it is possible that the public key was prefixed with 0x04, which is common in some implementations
payload_arr = <[u8; 32]>::try_from(&payload[1..]).unwrap();
} else if payload.len() != 32 {
{
let connection_state = connection_state.lock().await;
println!("Connection shared secret: {:?}", connection_state.shared_secret);
if connection_state.shared_secret.is_some() {
return None;
}
}
println!("Client Public Key payload: {}", hex::encode(payload.as_ref()));
if payload.len() != 32 {
return None;
} else {
payload_arr = payload.try_into().unwrap();
}
let peer_public_key = PublicKey::from(payload_arr);
println!("Peer Public Key: {}", hex::encode(peer_public_key.as_ref()));

let payload_arr: [u8; 32];
// Deserialize the public key sent by the client
// If payload length is 33, compressed 32 with 1 leading byte
payload_arr = <[u8; 32]>::try_from(&payload[..]).unwrap();
let client_public_key = PublicKey::from(payload_arr);
println!("Client Public Key: {:?}", client_public_key);
// Generate an ephemeral private key
let server_private_key = EphemeralSecret::random_from_rng(OsRng);
let mut server_public_key = PublicKey::from(&server_private_key);
// Perform the key agreement
let shared_secret = my_private_key.diffie_hellman(&peer_public_key);
let shared_secret = server_private_key.diffie_hellman(&client_public_key);
let shared_secret_bytes = shared_secret.as_bytes();
println!("Shared Secret +++++++++++++");
println!("Shared Secret: {}", hex::encode(shared_secret_bytes));
println!("Shared Secret +++++++++++++");
// Hash the shared secret
let mut hasher = Sha256::new();
hasher.update(shared_secret_bytes);
Expand All @@ -292,24 +282,19 @@ async fn handle_public_key(
// Convert server_public_key to bytes
let server_public_key_bytes = server_public_key.to_bytes();
// Determine prefix: 0x02 for even y, 0x03 for odd y
let prefix = if server_public_key_bytes[31] % 2 == 0 { 0x02 } else { 0x03 };
// Prepare a vector to hold the compressed key (1 byte prefix + 32 bytes key)
let mut server_compressed_public_key = Vec::with_capacity(33);
server_compressed_public_key.push(prefix);
server_compressed_public_key.extend_from_slice(&server_public_key_bytes);
println!("Compressed Server PublicKey: {:}", hex::encode(&server_compressed_public_key));
println!("Server Public Key Size: {} bytes", &server_compressed_public_key.len());
println!("Server Public Key Size: {:?} bytes", server_public_key);
// Send server_public_key as publicKey message
let mut response_data = Vec::new();
// Appending MessageType::PublicKey
response_data.push(MessageType::PublicKey as u8);
// Appending my_public_key bytes
response_data.append(&mut server_compressed_public_key);

response_data.extend_from_slice(&server_public_key_bytes);
// Update the connection state with the hashed shared secret
// TODO store symmetric key not shared secret
let mut connection_state = connection_state.lock().await;
connection_state.shared_secret = Some(hashed_secret.to_vec());
{
let mut connection_state = connection_state.lock().await;
connection_state.shared_secret = Some(hashed_secret.to_vec());
}
Some(Message::Binary(response_data))
}

Expand Down

0 comments on commit 36f0483

Please sign in to comment.