diff --git a/.gitignore b/.gitignore index 1ae4ea4..fe3a7ef 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ Cargo.lock /.idea/backend-rust.iml /.idea/modules.xml /.idea/vcs.xml +/fullchain.pem +/privkey.pem diff --git a/Cargo.toml b/Cargo.toml index bbaa85f..eb41626 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "backend-rust" -version = "0.3.0" +version = "0.4.0" edition = "2021" [dependencies] @@ -17,4 +17,8 @@ once_cell = "1.19.0" rand_core = "0.6.4" zeroize = "1.8.1" sha2 = "0.10.8" -hkdf = "0.12.4" \ No newline at end of file +hkdf = "0.12.4" +tokio-native-tls = "0.3.1" +native-tls = "0.2.12" +env_logger = "0.11.3" +log = "0.4.22" \ No newline at end of file diff --git a/README.md b/README.md index 0f2e210..d45f60f 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,51 @@ cargo build openssl ec -in recipient_private_key.pem -text -noout ``` - ```shell - openssl ecparam -name prime256v1 -genkey -noout -out kas_private_key.pem - ``` +2. Generating Self-Signed Certificate + +For development purposes, you can generate a self-signed certificate using OpenSSL. Run the following command in your +terminal: + +```bash +openssl req -x509 -newkey rsa:4096 -keyout privkey.pem -out fullchain.pem -days 365 -nodes -subj "/CN=localhost" +``` + +This command will generate two files in your current directory: + +- `privkey.pem`: The private key file +- `fullchain.pem`: The self-signed certificate file + +Note: Self-signed certificates should only be used for development and testing. For production environments, use a +certificate from a trusted Certificate Authority. + +#### Configuration + +The server can be configured using environment variables. If not set, default values will be used. + +| Environment Variable | Description | Default Value | +|----------------------|------------------------------------------|-----------------------------| +| PORT | The port on which the server will listen | 8443 | +| TLS_CERT_PATH | Path to the TLS certificate file | ./fullchain.pem | +| TLS_KEY_PATH | Path to the TLS private key file | ./privkey.pem | +| KAS_KEY_PATH | Path to the KAS private key file | ./recipient_private_key.pem | + +All file paths are relative to the current working directory where the server is run. + +```env +export PORT=8443 +export TLS_CERT_PATH=/path/to/fullchain.pem +export TLS_KEY_PATH=/path/to/privkey.pem +export KAS_KEY_PATH=/path/to/recipient_private_key.pem +export ENABLE_TIMING_LOGS=true +export RUST_LOG=info +``` + +(Optional) Set the environment variables if you want to override the defaults. + +##### Security Note + +Remember to keep your private keys secure and never commit them to version control systems. It's recommended to use +environment variables or secure vaults for managing sensitive information in production environments. 2. Start the server: @@ -48,6 +90,8 @@ cargo build cargo run ``` +The server will start and listen on the configured port (default: 8443) using HTTPS. + ### Usage - **Key Agreement**: The server establishes a shared secret with each client using ECDH. diff --git a/src/main.rs b/src/main.rs index f3462a6..ce0b3c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,7 @@ +use std::env; use std::sync::Arc; use std::sync::RwLock; +use std::time::Instant; use aes_gcm::aead::{Key, NewAead}; use aes_gcm::aead::Aead; @@ -8,6 +10,7 @@ use aes_gcm::Aes256Gcm; use elliptic_curve::point::AffineCoordinates; use futures_util::{SinkExt, StreamExt}; use hkdf::Hkdf; +use log::info; use once_cell::sync::OnceCell; use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey, SecretKey}; use p256::ecdh::EphemeralSecret; @@ -15,6 +18,7 @@ use rand_core::{OsRng, RngCore}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::net::{TcpListener, TcpStream}; +use tokio_native_tls::TlsAcceptor; use tokio_tungstenite::accept_async; use tokio_tungstenite::tungstenite::Message; @@ -72,29 +76,58 @@ struct KasKeys { static KAS_KEYS: OnceCell> = OnceCell::new(); #[tokio::main(flavor = "multi_thread", worker_threads = 4)] -async fn main() { - // KAS public key - init_kas_keys().expect("KAS key not loaded"); - // Bind the server to localhost on port 8080 - let try_socket = TcpListener::bind("0.0.0.0:8080").await; - let listener = match try_socket { - Ok(socket) => socket, - Err(e) => { - println!("Failed to bind to port: {}", e); - return; - } - }; - println!("Listening on: 0.0.0.0:8080"); +async fn main() -> Result<(), Box> { + // Initialize logging + env_logger::init(); + + // Load configuration + let settings = load_config()?; + + // Initialize KAS keys + init_kas_keys(&settings.kas_key_path)?; + + // Set up TLS + let tls_config = load_tls_config(&settings.tls_cert_path, &settings.tls_key_path)?; + let tls_acceptor = TlsAcceptor::from(tls_config); + + // Bind the server + let listener = TcpListener::bind(format!("0.0.0.0:{}", settings.port)).await?; + println!("Listening on: 0.0.0.0:{}", settings.port); + // Accept connections while let Ok((stream, _)) = listener.accept().await { + let tls_acceptor = tls_acceptor.clone(); let connection_state = Arc::new(ConnectionState::new()); + let settings_clone = settings.clone(); + tokio::spawn(async move { - handle_connection(stream, connection_state).await + match tls_acceptor.accept(stream).await { + Ok(tls_stream) => { + handle_connection(tls_stream, connection_state, &settings_clone).await; + } + Err(e) => eprintln!("Failed to accept TLS connection: {}", e), + } }); } + + Ok(()) +} + +fn load_tls_config(cert_path: &str, key_path: &str) -> Result> { + let cert = std::fs::read(cert_path)?; + let key = std::fs::read(key_path)?; + + let identity = native_tls::Identity::from_pkcs8(&cert, &key)?; + let acceptor = native_tls::TlsAcceptor::new(identity)?; + + Ok(acceptor) } -async fn handle_connection(stream: TcpStream, connection_state: Arc) { +async fn handle_connection( + stream: tokio_native_tls::TlsStream, + connection_state: Arc, + settings: &ServerSettings, +) { let ws_stream = match accept_async(stream).await { Ok(ws) => ws, Err(e) => { @@ -112,7 +145,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc, data: Vec, + settings: &ServerSettings, ) -> Option { if data.len() < 1 { println!("Invalid message format"); @@ -140,7 +174,7 @@ async fn handle_binary_message( match message_type { Some(MessageType::PublicKey) => handle_public_key(connection_state, payload).await, Some(MessageType::KasPublicKey) => handle_kas_public_key(payload).await, - Some(MessageType::Rewrap) => handle_rewrap(connection_state, payload).await, + Some(MessageType::Rewrap) => handle_rewrap(connection_state, payload, settings).await, Some(MessageType::RewrappedKey) => None, None => { println!("Unknown message type: {:?}", message_type); @@ -149,59 +183,61 @@ async fn handle_binary_message( } } -struct PrintOnDrop; - -impl Drop for PrintOnDrop { - fn drop(&mut self) { - // println!("END handle_rewrap"); - } -} +// struct PrintOnDrop; +// +// impl Drop for PrintOnDrop { +// fn drop(&mut self) { +// // println!("END handle_rewrap"); +// } +// } async fn handle_rewrap( connection_state: &Arc, payload: &[u8], + settings: &ServerSettings, ) -> Option { - let _print_on_drop = PrintOnDrop; - // println!("BEGIN handle_rewrap"); + let start_time = Instant::now(); + let session_shared_secret = { let shared_secret = connection_state.shared_secret_lock.read().unwrap(); shared_secret.clone() }; if session_shared_secret == None { - println!("Shared Secret not set"); + info!("Shared Secret not set"); return None; } let session_shared_secret = session_shared_secret.unwrap(); + // Parse NanoTDF header let mut parser = BinaryParser::new(payload); let header = match BinaryParser::parse_header(&mut parser) { Ok(header) => header, Err(e) => { - println!("Error parsing header: {:?}", e); + info!("Error parsing header: {:?}", e); return None; } }; - // Extract the policy - // let policy = header.get_policy(); - // println!("policy binding hex: {}", hex::encode(policy.get_binding().clone().unwrap())); + + let parse_time = start_time.elapsed(); + log_timing(settings, "Time to parse header", parse_time); + // TDF ephemeral key let tdf_ephemeral_key_bytes = header.get_ephemeral_key(); - // println!("tdf_ephemeral_key hex: {}", hex::encode(tdf_ephemeral_key_bytes)); - // Deserialize the public key sent by the client if tdf_ephemeral_key_bytes.len() != 33 { - println!("Invalid TDF compressed ephemeral key length"); + info!("Invalid TDF compressed ephemeral key length"); return None; } + // Deserialize the public key sent by the client let tdf_ephemeral_public_key = match PublicKey::from_sec1_bytes(tdf_ephemeral_key_bytes) { Ok(key) => key, Err(e) => { - println!("Error deserializing TDF ephemeral public key: {:?}", e); + info!("Error deserializing TDF ephemeral public key: {:?}", e); return None; } }; + let kas_private_key_bytes = get_kas_private_key_bytes().unwrap(); - // println!("kas_private_key_bytes {}", hex::encode(&kas_private_key_bytes)); let kas_private_key_array: [u8; 32] = match kas_private_key_bytes.try_into() { Ok(key) => key, Err(_) => return None, @@ -209,36 +245,49 @@ async fn handle_rewrap( let kas_private_key = SecretKey::from_bytes(&kas_private_key_array.into()) .map_err(|_| "Invalid private key") .ok()?; + // Perform custom ECDH + let ecdh_start = Instant::now(); let dek_shared_secret_bytes = match custom_ecdh(&kas_private_key, &tdf_ephemeral_public_key) { Ok(secret) => secret, Err(e) => { - println!("Error performing ECDH: {:?}", e); + info!("Error performing ECDH: {:?}", e); return None; } }; + let ecdh_time = ecdh_start.elapsed(); + log_timing(settings, "Time for ECDH operation", ecdh_time); + // Encrypt dek_shared_secret with symmetric key using AES GCM let salt = connection_state.salt_lock.read().unwrap().clone().unwrap(); let info = "rewrappedKey".as_bytes(); let hkdf = Hkdf::::new(Some(&salt), &session_shared_secret); let mut derived_key = [0u8; 32]; hkdf.expand(info, &mut derived_key).expect("HKDF expansion failed"); - // println!("Derived Session Key: {}", hex::encode(&derived_key)); + let mut nonce = [0u8; 12]; OsRng.fill_bytes(&mut nonce); let nonce = GenericArray::from_slice(&nonce); - // println!("nonce {}", hex::encode(nonce)); + let key = Key::::from(derived_key); let cipher = Aes256Gcm::new(&key); + + let encryption_start = Instant::now(); let wrapped_dek = cipher.encrypt(nonce, dek_shared_secret_bytes.as_ref()) .expect("encryption failure!"); - // println!("Rewrapped Key and Authentication tag {}", hex::encode(&wrapped_dek)); + let encryption_time = encryption_start.elapsed(); + log_timing(settings, "Time for AES-GCM encryption", encryption_time); + // binary response let mut response_data = Vec::new(); response_data.push(MessageType::RewrappedKey as u8); response_data.extend_from_slice(tdf_ephemeral_key_bytes); response_data.extend_from_slice(&nonce); response_data.extend_from_slice(&wrapped_dek); + + let total_time = start_time.elapsed(); + log_timing(settings, "Total time for handle_rewrap", total_time); + Some(Message::Binary(response_data)) } @@ -316,8 +365,8 @@ async fn handle_kas_public_key(_: &[u8]) -> Option { None } -fn init_kas_keys() -> Result<(), Box> { - let pem_content = std::fs::read_to_string("recipient_private_key.pem")?; +fn init_kas_keys(key_path: &str) -> Result<(), Box> { + let pem_content = std::fs::read_to_string(key_path)?; let ec_pem_contents = pem_content.as_bytes(); let pem = pem::parse(ec_pem_contents)?; if pem.tag() != "EC PRIVATE KEY" { @@ -374,6 +423,35 @@ fn custom_ecdh(secret_key: &SecretKey, public_key: &PublicKey) -> Result Ok(shared_secret) } +#[derive(Debug, Deserialize, Clone)] +struct ServerSettings { + port: u16, + tls_cert_path: String, + tls_key_path: String, + kas_key_path: String, + enable_timing_logs: bool, +} + +fn log_timing(settings: &ServerSettings, message: &str, duration: std::time::Duration) { + if settings.enable_timing_logs { + info!("{}: {:?}", message, duration); + } +} +fn load_config() -> Result> { + let current_dir = env::current_dir()?; + + Ok(ServerSettings { + port: env::var("PORT").unwrap_or_else(|_| "8443".to_string()).parse()?, + tls_cert_path: env::var("TLS_CERT_PATH") + .unwrap_or_else(|_| current_dir.join("fullchain.pem").to_str().unwrap().to_string()), + tls_key_path: env::var("TLS_KEY_PATH") + .unwrap_or_else(|_| current_dir.join("privkey.pem").to_str().unwrap().to_string()), + kas_key_path: env::var("KAS_KEY_PATH") + .unwrap_or_else(|_| current_dir.join("recipient_private_key.pem").to_str().unwrap().to_string()), + enable_timing_logs: env::var("ENABLE_TIMING_LOGS").unwrap_or_else(|_| "false".to_string()).parse().unwrap_or(false), + }) +} + #[cfg(test)] mod tests { use std::error::Error;