Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HTTPS with rewrap timings #3

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ Cargo.lock
/.idea/backend-rust.iml
/.idea/modules.xml
/.idea/vcs.xml
/fullchain.pem
/privkey.pem
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "backend-rust"
version = "0.3.0"
version = "0.4.0"
edition = "2021"

[dependencies]
Expand All @@ -17,4 +17,8 @@ once_cell = "1.19.0"
rand_core = "0.6.4"
zeroize = "1.8.1"
sha2 = "0.10.8"
hkdf = "0.12.4"
hkdf = "0.12.4"
tokio-native-tls = "0.3.1"
native-tls = "0.2.12"
env_logger = "0.11.3"
log = "0.4.22"
50 changes: 47 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,60 @@ cargo build
openssl ec -in recipient_private_key.pem -text -noout
```

```shell
openssl ecparam -name prime256v1 -genkey -noout -out kas_private_key.pem
```
2. Generating Self-Signed Certificate

For development purposes, you can generate a self-signed certificate using OpenSSL. Run the following command in your
terminal:

```bash
openssl req -x509 -newkey rsa:4096 -keyout privkey.pem -out fullchain.pem -days 365 -nodes -subj "/CN=localhost"
```

This command will generate two files in your current directory:

- `privkey.pem`: The private key file
- `fullchain.pem`: The self-signed certificate file

Note: Self-signed certificates should only be used for development and testing. For production environments, use a
certificate from a trusted Certificate Authority.

#### Configuration

The server can be configured using environment variables. If not set, default values will be used.

| Environment Variable | Description | Default Value |
|----------------------|------------------------------------------|-----------------------------|
| PORT | The port on which the server will listen | 8443 |
| TLS_CERT_PATH | Path to the TLS certificate file | ./fullchain.pem |
| TLS_KEY_PATH | Path to the TLS private key file | ./privkey.pem |
| KAS_KEY_PATH | Path to the KAS private key file | ./recipient_private_key.pem |

All file paths are relative to the current working directory where the server is run.

```env
export PORT=8443
export TLS_CERT_PATH=/path/to/fullchain.pem
export TLS_KEY_PATH=/path/to/privkey.pem
export KAS_KEY_PATH=/path/to/recipient_private_key.pem
export ENABLE_TIMING_LOGS=true
export RUST_LOG=info
```

(Optional) Set the environment variables if you want to override the defaults.

##### Security Note

Remember to keep your private keys secure and never commit them to version control systems. It's recommended to use
environment variables or secure vaults for managing sensitive information in production environments.

2. Start the server:

```shell
cargo run
```

The server will start and listen on the configured port (default: 8443) using HTTPS.

### Usage

- **Key Agreement**: The server establishes a shared secret with each client using ECDH.
Expand Down
162 changes: 120 additions & 42 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::env;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Instant;

use aes_gcm::aead::{Key, NewAead};
use aes_gcm::aead::Aead;
Expand All @@ -8,13 +10,15 @@ use aes_gcm::Aes256Gcm;
use elliptic_curve::point::AffineCoordinates;
use futures_util::{SinkExt, StreamExt};
use hkdf::Hkdf;
use log::info;
use once_cell::sync::OnceCell;
use p256::{elliptic_curve::sec1::ToEncodedPoint, PublicKey, SecretKey};
use p256::ecdh::EphemeralSecret;
use rand_core::{OsRng, RngCore};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::net::{TcpListener, TcpStream};
use tokio_native_tls::TlsAcceptor;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::tungstenite::Message;

Expand Down Expand Up @@ -72,29 +76,58 @@ struct KasKeys {
static KAS_KEYS: OnceCell<Arc<KasKeys>> = OnceCell::new();

#[tokio::main(flavor = "multi_thread", worker_threads = 4)]
async fn main() {
// KAS public key
init_kas_keys().expect("KAS key not loaded");
// Bind the server to localhost on port 8080
let try_socket = TcpListener::bind("0.0.0.0:8080").await;
let listener = match try_socket {
Ok(socket) => socket,
Err(e) => {
println!("Failed to bind to port: {}", e);
return;
}
};
println!("Listening on: 0.0.0.0:8080");
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging
env_logger::init();

// Load configuration
let settings = load_config()?;

// Initialize KAS keys
init_kas_keys(&settings.kas_key_path)?;

// Set up TLS
let tls_config = load_tls_config(&settings.tls_cert_path, &settings.tls_key_path)?;
let tls_acceptor = TlsAcceptor::from(tls_config);

// Bind the server
let listener = TcpListener::bind(format!("0.0.0.0:{}", settings.port)).await?;
println!("Listening on: 0.0.0.0:{}", settings.port);

// Accept connections
while let Ok((stream, _)) = listener.accept().await {
let tls_acceptor = tls_acceptor.clone();
let connection_state = Arc::new(ConnectionState::new());
let settings_clone = settings.clone();

tokio::spawn(async move {
handle_connection(stream, connection_state).await
match tls_acceptor.accept(stream).await {
Ok(tls_stream) => {
handle_connection(tls_stream, connection_state, &settings_clone).await;
}
Err(e) => eprintln!("Failed to accept TLS connection: {}", e),
}
});
}

Ok(())
}

fn load_tls_config(cert_path: &str, key_path: &str) -> Result<native_tls::TlsAcceptor, Box<dyn std::error::Error>> {
let cert = std::fs::read(cert_path)?;
let key = std::fs::read(key_path)?;

let identity = native_tls::Identity::from_pkcs8(&cert, &key)?;
let acceptor = native_tls::TlsAcceptor::new(identity)?;

Ok(acceptor)
}

async fn handle_connection(stream: TcpStream, connection_state: Arc<ConnectionState>) {
async fn handle_connection(
stream: tokio_native_tls::TlsStream<TcpStream>,
connection_state: Arc<ConnectionState>,
settings: &ServerSettings,
) {
let ws_stream = match accept_async(stream).await {
Ok(ws) => ws,
Err(e) => {
Expand All @@ -112,7 +145,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<ConnectionSt
println!("Received a close message.");
return;
}
if let Some(response) = handle_binary_message(&connection_state, msg.into_data()).await
if let Some(response) = handle_binary_message(&connection_state, msg.into_data(), settings).await
{
// TODO remove clone
ws_sender.send(response.clone()).await.expect("ws send failed");
Expand All @@ -129,6 +162,7 @@ async fn handle_connection(stream: TcpStream, connection_state: Arc<ConnectionSt
async fn handle_binary_message(
connection_state: &Arc<ConnectionState>,
data: Vec<u8>,
settings: &ServerSettings,
) -> Option<Message> {
if data.len() < 1 {
println!("Invalid message format");
Expand All @@ -140,7 +174,7 @@ async fn handle_binary_message(
match message_type {
Some(MessageType::PublicKey) => handle_public_key(connection_state, payload).await,
Some(MessageType::KasPublicKey) => handle_kas_public_key(payload).await,
Some(MessageType::Rewrap) => handle_rewrap(connection_state, payload).await,
Some(MessageType::Rewrap) => handle_rewrap(connection_state, payload, settings).await,
Some(MessageType::RewrappedKey) => None,
None => {
println!("Unknown message type: {:?}", message_type);
Expand All @@ -149,96 +183,111 @@ async fn handle_binary_message(
}
}

struct PrintOnDrop;

impl Drop for PrintOnDrop {
fn drop(&mut self) {
// println!("END handle_rewrap");
}
}
// struct PrintOnDrop;
//
// impl Drop for PrintOnDrop {
// fn drop(&mut self) {
// // println!("END handle_rewrap");
// }
// }

async fn handle_rewrap(
connection_state: &Arc<ConnectionState>,
payload: &[u8],
settings: &ServerSettings,
) -> Option<Message> {
let _print_on_drop = PrintOnDrop;
// println!("BEGIN handle_rewrap");
let start_time = Instant::now();

let session_shared_secret = {
let shared_secret = connection_state.shared_secret_lock.read().unwrap();
shared_secret.clone()
};
if session_shared_secret == None {
println!("Shared Secret not set");
info!("Shared Secret not set");
return None;
}
let session_shared_secret = session_shared_secret.unwrap();

// Parse NanoTDF header
let mut parser = BinaryParser::new(payload);
let header = match BinaryParser::parse_header(&mut parser) {
Ok(header) => header,
Err(e) => {
println!("Error parsing header: {:?}", e);
info!("Error parsing header: {:?}", e);
return None;
}
};
// Extract the policy
// let policy = header.get_policy();
// println!("policy binding hex: {}", hex::encode(policy.get_binding().clone().unwrap()));

let parse_time = start_time.elapsed();
log_timing(settings, "Time to parse header", parse_time);

// TDF ephemeral key
let tdf_ephemeral_key_bytes = header.get_ephemeral_key();
// println!("tdf_ephemeral_key hex: {}", hex::encode(tdf_ephemeral_key_bytes));
// Deserialize the public key sent by the client
if tdf_ephemeral_key_bytes.len() != 33 {
println!("Invalid TDF compressed ephemeral key length");
info!("Invalid TDF compressed ephemeral key length");
return None;
}

// Deserialize the public key sent by the client
let tdf_ephemeral_public_key = match PublicKey::from_sec1_bytes(tdf_ephemeral_key_bytes) {
Ok(key) => key,
Err(e) => {
println!("Error deserializing TDF ephemeral public key: {:?}", e);
info!("Error deserializing TDF ephemeral public key: {:?}", e);
return None;
}
};

let kas_private_key_bytes = get_kas_private_key_bytes().unwrap();
// println!("kas_private_key_bytes {}", hex::encode(&kas_private_key_bytes));
let kas_private_key_array: [u8; 32] = match kas_private_key_bytes.try_into() {
Ok(key) => key,
Err(_) => return None,
};
let kas_private_key = SecretKey::from_bytes(&kas_private_key_array.into())
.map_err(|_| "Invalid private key")
.ok()?;

// Perform custom ECDH
let ecdh_start = Instant::now();
let dek_shared_secret_bytes = match custom_ecdh(&kas_private_key, &tdf_ephemeral_public_key) {
Ok(secret) => secret,
Err(e) => {
println!("Error performing ECDH: {:?}", e);
info!("Error performing ECDH: {:?}", e);
return None;
}
};
let ecdh_time = ecdh_start.elapsed();
log_timing(settings, "Time for ECDH operation", ecdh_time);

// Encrypt dek_shared_secret with symmetric key using AES GCM
let salt = connection_state.salt_lock.read().unwrap().clone().unwrap();
let info = "rewrappedKey".as_bytes();
let hkdf = Hkdf::<Sha256>::new(Some(&salt), &session_shared_secret);
let mut derived_key = [0u8; 32];
hkdf.expand(info, &mut derived_key).expect("HKDF expansion failed");
// println!("Derived Session Key: {}", hex::encode(&derived_key));

let mut nonce = [0u8; 12];
OsRng.fill_bytes(&mut nonce);
let nonce = GenericArray::from_slice(&nonce);
// println!("nonce {}", hex::encode(nonce));

let key = Key::<Aes256Gcm>::from(derived_key);
let cipher = Aes256Gcm::new(&key);

let encryption_start = Instant::now();
let wrapped_dek = cipher.encrypt(nonce, dek_shared_secret_bytes.as_ref())
.expect("encryption failure!");
// println!("Rewrapped Key and Authentication tag {}", hex::encode(&wrapped_dek));
let encryption_time = encryption_start.elapsed();
log_timing(settings, "Time for AES-GCM encryption", encryption_time);

// binary response
let mut response_data = Vec::new();
response_data.push(MessageType::RewrappedKey as u8);
response_data.extend_from_slice(tdf_ephemeral_key_bytes);
response_data.extend_from_slice(&nonce);
response_data.extend_from_slice(&wrapped_dek);

let total_time = start_time.elapsed();
log_timing(settings, "Total time for handle_rewrap", total_time);

Some(Message::Binary(response_data))
}

Expand Down Expand Up @@ -316,8 +365,8 @@ async fn handle_kas_public_key(_: &[u8]) -> Option<Message> {
None
}

fn init_kas_keys() -> Result<(), Box<dyn std::error::Error>> {
let pem_content = std::fs::read_to_string("recipient_private_key.pem")?;
fn init_kas_keys(key_path: &str) -> Result<(), Box<dyn std::error::Error>> {
let pem_content = std::fs::read_to_string(key_path)?;
let ec_pem_contents = pem_content.as_bytes();
let pem = pem::parse(ec_pem_contents)?;
if pem.tag() != "EC PRIVATE KEY" {
Expand Down Expand Up @@ -374,6 +423,35 @@ fn custom_ecdh(secret_key: &SecretKey, public_key: &PublicKey) -> Result<Vec<u8>
Ok(shared_secret)
}

#[derive(Debug, Deserialize, Clone)]
struct ServerSettings {
port: u16,
tls_cert_path: String,
tls_key_path: String,
kas_key_path: String,
enable_timing_logs: bool,
}

fn log_timing(settings: &ServerSettings, message: &str, duration: std::time::Duration) {
if settings.enable_timing_logs {
info!("{}: {:?}", message, duration);
}
}
fn load_config() -> Result<ServerSettings, Box<dyn std::error::Error>> {
let current_dir = env::current_dir()?;

Ok(ServerSettings {
port: env::var("PORT").unwrap_or_else(|_| "8443".to_string()).parse()?,
tls_cert_path: env::var("TLS_CERT_PATH")
.unwrap_or_else(|_| current_dir.join("fullchain.pem").to_str().unwrap().to_string()),
tls_key_path: env::var("TLS_KEY_PATH")
.unwrap_or_else(|_| current_dir.join("privkey.pem").to_str().unwrap().to_string()),
kas_key_path: env::var("KAS_KEY_PATH")
.unwrap_or_else(|_| current_dir.join("recipient_private_key.pem").to_str().unwrap().to_string()),
enable_timing_logs: env::var("ENABLE_TIMING_LOGS").unwrap_or_else(|_| "false".to_string()).parse().unwrap_or(false),
})
}

#[cfg(test)]
mod tests {
use std::error::Error;
Expand Down