Skip to content

Commit

Permalink
Add NTRU KEM per Bernstein's recommendation (#197)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbraun96 authored Oct 23, 2023
1 parent 319ca88 commit 8b18bf9
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 32 deletions.
2 changes: 1 addition & 1 deletion citadel_pqcrypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ascon-aead = { workspace = true }
zeroize = { workspace = true, features = ["zeroize_derive", "alloc", "serde"] }

[target.'cfg(not(target_family = "wasm"))'.dependencies]
oqs = { workspace = true, features = ["serde", "falcon"] }
oqs = { workspace = true, features = ["serde", "falcon", "ntruprime"] }

[target.'cfg(target_family = "wasm")'.dependencies]
pqcrypto-falcon-wasi = { workspace = true, features = ["serialization", "avx2"] }
Expand Down
21 changes: 16 additions & 5 deletions citadel_pqcrypto/src/encryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,21 +209,32 @@ pub(crate) mod kyber_module {
}

pub fn encrypt_pke<T: AsRef<[u8]>, R: AsRef<[u8]>, V: AsRef<[u8]>>(
_: KemAlgorithm,
kem_alg: KemAlgorithm,
local_pk: T,
plaintext: R,
nonce: V,
) -> Result<Vec<u8>, Error> {
kyber_pke::encrypt(local_pk, plaintext, nonce)
.map_err(|err| Error::Other(format!("{err:?}")))
match kem_alg {
KemAlgorithm::Kyber => kyber_pke::encrypt(local_pk, plaintext, nonce)
.map_err(|err| Error::Other(format!("{err:?}"))),
KemAlgorithm::Ntru => Err(Error::Other(format!(
"Kem ALG {kem_alg:?} does not support PKE"
))),
}
}

pub fn decrypt_pke<T: AsRef<[u8]>, R: AsRef<[u8]>>(
_: KemAlgorithm,
kem_alg: KemAlgorithm,
local_sk: T,
ciphertext: R,
) -> Result<Vec<u8>, Error> {
kyber_pke::decrypt(local_sk, ciphertext).map_err(|err| Error::Other(format!("{err:?}")))
match kem_alg {
KemAlgorithm::Kyber => kyber_pke::decrypt(local_sk, ciphertext)
.map_err(|err| Error::Other(format!("{err:?}"))),
KemAlgorithm::Ntru => Err(Error::Other(format!(
"Kem ALG {kem_alg:?} does not support PKE"
))),
}
}

fn encode_length_be_bytes(len: usize, buf: &mut dyn Buffer) -> Result<(), Error> {
Expand Down
83 changes: 73 additions & 10 deletions citadel_pqcrypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,10 @@ pub mod algorithm_dictionary {
}

pub fn plaintext_length(&self, ciphertext: &[u8]) -> Option<usize> {
if ciphertext.len() < 16 {
return None;
}

match self {
Self::AES_GCM_256 => Some(ciphertext.len() - 16),
Self::ChaCha20Poly_1305 => Some(ciphertext.len() - 16),
Expand Down Expand Up @@ -835,6 +839,7 @@ pub mod algorithm_dictionary {
#[strum(ascii_case_insensitive)]
#[default]
Kyber = 0,
Ntru = 1,
}

#[derive(
Expand Down Expand Up @@ -995,15 +1000,29 @@ pub enum PostQuantumMeta {
impl PostQuantumMeta {
fn new_alice(kem_alg: KemAlgorithm, sig_alg: SigAlgorithm) -> Result<Self, Error> {
log::trace!(target: "citadel", "About to generate keypair for {:?}", kem_alg);
let pk_alice = kyber_pke::kem_keypair().map_err(|err| Error::Other(err.to_string()))?;
let (public_key, secret_key) = (pk_alice.public, pk_alice.secret);
let (public_key, secret_key) = match kem_alg {
KemAlgorithm::Kyber => {
let pk_alice =
kyber_pke::kem_keypair().map_err(|err| Error::Other(err.to_string()))?;
(pk_alice.public.to_vec(), pk_alice.secret.to_vec())
}
KemAlgorithm::Ntru => {
let (pk_alice, sk_alice) =
oqs::kem::Kem::new(oqs::kem::Algorithm::NtruPrimeSntrup761)
.map_err(|err| Error::Other(err.to_string()))?
.keypair()
.map_err(|err| Error::Other(err.to_string()))?;
(pk_alice.into_vec(), sk_alice.into_vec())
}
};

let ciphertext = None;
let shared_secret = None;
let remote_sig_public_key = None;
let secret_key = Some(Arc::new(secret_key.to_vec().into()));
let secret_key = Some(Arc::new(secret_key.into()));

let kex = PostQuantumMetaKex {
public_key: Arc::new(public_key.to_vec().into()),
public_key: Arc::new(public_key.into()),
secret_key,
ciphertext,
shared_secret,
Expand Down Expand Up @@ -1042,9 +1061,20 @@ impl PostQuantumMeta {
} => (*kem_scheme, alice_pk),
};

let pk_bob = kyber_pke::kem_keypair().map_err(|err| Error::Other(err.to_string()))?;
let (kem_pk_bob, kem_sk_bob) = (pk_bob.public, pk_bob.secret);
let (kem_pk_bob, kem_sk_bob) = (kem_pk_bob.to_vec(), kem_sk_bob.to_vec());
let (kem_pk_bob, kem_sk_bob) = match kem_scheme {
KemAlgorithm::Kyber => {
let pk_bob =
kyber_pke::kem_keypair().map_err(|err| Error::Other(err.to_string()))?;
(pk_bob.public.to_vec(), pk_bob.secret.to_vec())
}
KemAlgorithm::Ntru => {
let (pk_bob, sk_bob) = oqs::kem::Kem::new(oqs::kem::Algorithm::NtruPrimeSntrup761)
.map_err(|err| Error::Other(err.to_string()))?
.keypair()
.map_err(|err| Error::Other(err.to_string()))?;
(pk_bob.into_vec(), sk_bob.into_vec())
}
};

let (ciphertext, shared_secret) = match kem_scheme {
KemAlgorithm::Kyber => {
Expand All @@ -1053,6 +1083,19 @@ impl PostQuantumMeta {
.map_err(|_err| get_generic_error("Failed encapsulate step"))?;
(ciphertext.to_vec(), shared_secret.to_vec())
}

KemAlgorithm::Ntru => {
let kem = oqs::kem::Kem::new(oqs::kem::Algorithm::NtruPrimeSntrup761)
.map_err(|err| Error::Other(err.to_string()))?;
let wrapper = kem
.public_key_from_bytes(pk_alice.as_slice())
.ok_or_else(|| Error::Other("Bad public key input".to_string()))?;

let (ciphertext, shared_secret) = kem
.encapsulate(wrapper)
.map_err(|err| Error::Other(err.to_string()))?;
(ciphertext.into_vec(), shared_secret.into_vec())
}
};

let public_key = Arc::new(kem_pk_bob.into());
Expand Down Expand Up @@ -1141,9 +1184,29 @@ impl PostQuantumMeta {
};

let secret_key = self.get_secret_key()?;
let shared_secret = kyber_pke::decapsulate(&bob_ciphertext, secret_key)
.map_err(|err| Error::Other(err.to_string()))?;
self.get_kex_mut().shared_secret = Some(Arc::new(shared_secret.to_vec().into()));

let shared_secret = match self.kex().kem_alg {
KemAlgorithm::Kyber => kyber_pke::decapsulate(&bob_ciphertext, secret_key)
.map_err(|err| Error::Other(err.to_string()))?
.to_vec(),
KemAlgorithm::Ntru => {
let kem = oqs::kem::Kem::new(oqs::kem::Algorithm::NtruPrimeSntrup761)
.map_err(|err| Error::Other(err.to_string()))?;

let wrapper_sk = kem
.secret_key_from_bytes(secret_key.as_slice())
.ok_or_else(|| Error::Other("Bad secret key input".to_string()))?;
let wrapper_ct = kem
.ciphertext_from_bytes(bob_ciphertext.as_slice())
.ok_or_else(|| Error::Other("Bad ciphertext input".to_string()))?;

kem.decapsulate(wrapper_sk, wrapper_ct)
.map_err(|err| Error::Other(err.to_string()))?
.into_vec()
}
};

self.get_kex_mut().shared_secret = Some(Arc::new(shared_secret.into()));
self.get_kex_mut().ciphertext = Some(bob_ciphertext);

match params {
Expand Down
20 changes: 12 additions & 8 deletions citadel_pqcrypto/tests/primary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ mod tests {

assert_eq!(plaintext.as_slice(), decrypted);

// test local encryption
local_encryption(&alice_container, &bob_container, &plaintext, nonce);
if kem_algorithm == KemAlgorithm::Kyber {
// test local encryption
local_encryption(&alice_container, &bob_container, &plaintext, nonce);
}
}

Ok(())
Expand Down Expand Up @@ -355,12 +357,14 @@ mod tests {
SigAlgorithm::None,
)
.unwrap();
run(
algorithm.as_u8(),
EncryptionAlgorithm::Kyber,
SigAlgorithm::Falcon1024,
)
.unwrap();
if algorithm == KemAlgorithm::Kyber {
run(
algorithm.as_u8(),
EncryptionAlgorithm::Kyber,
SigAlgorithm::Falcon1024,
)
.unwrap();
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion citadel_proto/src/proto/node_result.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::prelude::{GroupBroadcast, GroupChannel, PeerChannel, PeerSignal, UdpChannel};
use crate::prelude::{
GroupBroadcast, GroupChannel, PeerChannel, PeerSignal, SessionSecuritySettings, UdpChannel,
};
use crate::proto::peer::peer_layer::MailboxTransfer;
use crate::proto::remote::Ticket;
use crate::proto::state_container::VirtualConnectionType;
Expand Down Expand Up @@ -39,6 +41,7 @@ pub struct ConnectSuccess {
pub welcome_message: String,
pub channel: PeerChannel,
pub udp_rx_opt: Option<tokio::sync::oneshot::Receiver<UdpChannel>>,
pub session_security_settings: SessionSecuritySettings,
}

#[derive(Debug)]
Expand Down
11 changes: 10 additions & 1 deletion citadel_proto/src/proto/packet_processor/connect_packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ pub async fn process_connect(
session,
);

let session_security_settings = state_container
.session_security_settings
.expect("Should be set");

drop(state_container);

// Upgrade the connect BEFORE updating the CNAC
Expand Down Expand Up @@ -146,7 +150,8 @@ pub async fn process_connect(
services: post_login_object,
welcome_message: format!("Client {cid} successfully established a connection to the local HyperNode"),
channel,
udp_rx_opt: udp_channel_rx
udp_rx_opt: udp_channel_rx,
session_security_settings
});
// safe unwrap. Store the signal
inner_mut_state!(session.state_container)
Expand Down Expand Up @@ -255,6 +260,9 @@ pub async fn process_connect(
header.session_cid.get(),
session,
);
let session_security_settings = state_container
.session_security_settings
.expect("Should be set");
std::mem::drop(state_container);

session.implicated_cid.set(Some(cid)); // This makes is_provisional equal to false
Expand Down Expand Up @@ -305,6 +313,7 @@ pub async fn process_connect(
welcome_message: message,
channel,
udp_rx_opt: udp_channel_rx,
session_security_settings,
}))?;
//finally, if there are any mailbox items, send them to the kernel for processing
if let Some(mailbox_delivery) = payload.mailbox {
Expand Down
4 changes: 2 additions & 2 deletions citadel_sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
//!
//! # Encryption Algorithms
//! The user may also select a symmetric encryption algorithm before a session starts (see: [SessionSecuritySettingsBuilder](crate::prelude::SessionSecuritySettingsBuilder))
//! - AES-256-GCM-SIV
//! - XChacha20Poly-1305
//! - AES-256-GCM
//! - Chacha20Poly-1305
//! - Ascon-80pq
//! - Kyber "scramcryption" (see below for explanation)
//!
Expand Down
2 changes: 2 additions & 0 deletions citadel_sdk/src/prefabs/client/single_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ where
tracing::instrument(target = "citadel", skip_all, err(Debug))
)]
async fn on_start(&self) -> Result<(), NetworkError> {
let session_security_settings = self.session_security_settings;
let remote = self.remote.clone().unwrap();
let (auth_info, handler) = {
(
Expand Down Expand Up @@ -259,6 +260,7 @@ where
inner: remote,
unprocessed_signals_rx: Arc::new(Mutex::new(unprocessed_signal_filter)),
conn_type,
session_security_settings,
},
)
.await
Expand Down
12 changes: 11 additions & 1 deletion citadel_sdk/src/prefabs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@ pub struct ClientServerRemote {
pub(crate) inner: NodeRemote,
pub(crate) unprocessed_signals_rx: Arc<Mutex<Option<UnboundedReceiver<NodeResult>>>>,
conn_type: VirtualTargetType,
session_security_settings: SessionSecuritySettings,
}

impl_remote!(ClientServerRemote);

impl ClientServerRemote {
/// constructs a new [`ClientServerRemote`] from a [`NodeRemote`] and a [`VirtualTargetType`]
pub fn new(conn_type: VirtualTargetType, remote: NodeRemote) -> Self {
pub fn new(
conn_type: VirtualTargetType,
remote: NodeRemote,
session_security_settings: SessionSecuritySettings,
) -> Self {
Self {
inner: remote,
unprocessed_signals_rx: Default::default(),
conn_type,
session_security_settings,
}
}
/// Can only be called once per remote. Allows receiving events
Expand All @@ -55,6 +61,10 @@ impl TargetLockedRemote for ClientServerRemote {
fn user_mut(&mut self) -> &mut VirtualTargetType {
&mut self.conn_type
}

fn session_security_settings(&self) -> Option<&SessionSecuritySettings> {
Some(&self.session_security_settings)
}
}

impl ClientServerRemote {
Expand Down
9 changes: 7 additions & 2 deletions citadel_sdk/src/prefabs/server/client_connect_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,20 @@ where
welcome_message: _,
channel,
udp_rx_opt: udp_channel_rx,
session_security_settings,
}) => {
let client_server_remote =
ClientServerRemote::new(conn_type, self.node_remote.clone().unwrap());
let client_server_remote = ClientServerRemote::new(
conn_type,
self.node_remote.clone().unwrap(),
session_security_settings,
);
(self.on_channel_received)(
ConnectionSuccess {
channel,
udp_channel_rx,
services,
cid,
session_security_settings,
},
client_server_remote,
)
Expand Down
Loading

0 comments on commit 8b18bf9

Please sign in to comment.