diff --git a/Cargo.lock b/Cargo.lock index 59e5b8b..211392c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -515,6 +515,7 @@ dependencies = [ "serde", "serde_json", "sha3", + "zeroize", ] [[package]] diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index 1b44806..36aac25 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -19,12 +19,14 @@ exclude = ["tests/key-gen.rs", "tests/key-gen.json", "tests/encap-decap.rs", "te default = ["std"] std = ["sha3/std"] deterministic = [] # Expose deterministic generation and encapsulation functions +zeroize = ["dep:zeroize"] [dependencies] kem = "0.3.0-pre.0" hybrid-array = { version = "0.2.0-rc.9", features = ["extra-sizes"] } rand_core = "0.6.4" sha3 = { version = "0.10.8", default-features = false } +zeroize = { version = "1.8.1", optional = true, default-features = false } [dev-dependencies] criterion = "0.5.1" diff --git a/ml-kem/src/algebra.rs b/ml-kem/src/algebra.rs index d4a42ad..dad1de0 100644 --- a/ml-kem/src/algebra.rs +++ b/ml-kem/src/algebra.rs @@ -7,6 +7,9 @@ use crate::encode::Encode; use crate::param::{ArraySize, CbdSamplingSize}; use crate::util::{Truncate, B32}; +#[cfg(feature = "zeroize")] +use zeroize::Zeroize; + pub type Integer = u16; /// An element of GF(q). Although `q` is only 16 bits wide, we use a wider uint type to so that we @@ -14,6 +17,13 @@ pub type Integer = u16; #[derive(Copy, Clone, Debug, Default, PartialEq)] pub struct FieldElement(pub Integer); +#[cfg(feature = "zeroize")] +impl Zeroize for FieldElement { + fn zeroize(&mut self) { + self.0.zeroize(); + } +} + impl FieldElement { pub const Q: Integer = 3329; pub const Q32: u32 = Self::Q as u32; @@ -174,6 +184,15 @@ impl PolynomialVector { #[derive(Clone, Default, Debug, PartialEq)] pub struct NttPolynomial(pub Array); +#[cfg(feature = "zeroize")] +impl Zeroize for NttPolynomial { + fn zeroize(&mut self) { + for fe in self.0.iter_mut() { + fe.zeroize() + } + } +} + impl Add<&NttPolynomial> for &NttPolynomial { type Output = NttPolynomial; @@ -410,6 +429,18 @@ impl NttVector { } } +#[cfg(feature = "zeroize")] +impl Zeroize for NttVector +where + K: ArraySize, +{ + fn zeroize(&mut self) { + for poly in self.0.iter_mut() { + poly.zeroize(); + } + } +} + impl Add<&NttVector> for &NttVector { type Output = NttVector; diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 573f9f2..0047125 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -8,6 +8,9 @@ use crate::pke::{DecryptionKey, EncryptionKey}; use crate::util::B32; use crate::{Encoded, EncodedSizeUser}; +#[cfg(feature = "zeroize")] +use zeroize::{Zeroize, ZeroizeOnDrop}; + // Re-export traits from the `kem` crate pub use ::kem::{Decapsulate, Encapsulate}; @@ -26,6 +29,20 @@ where z: B32, } +#[cfg(feature = "zeroize")] +impl

Drop for DecapsulationKey

+where + P: KemParams, +{ + fn drop(&mut self) { + self.dk_pke.zeroize(); + self.z.zeroize(); + } +} + +#[cfg(feature = "zeroize")] +impl

ZeroizeOnDrop for DecapsulationKey

where P: KemParams {} + impl

EncodedSizeUser for DecapsulationKey

where P: KemParams, diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index c4d7369..571a1d6 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -7,6 +7,9 @@ use crate::encode::Encode; use crate::param::{EncodedCiphertext, EncodedDecryptionKey, EncodedEncryptionKey, PkeParams}; use crate::util::B32; +#[cfg(feature = "zeroize")] +use zeroize::Zeroize; + /// A `DecryptionKey` provides the ability to generate a new key pair, and decrypt an /// encrypted value. #[derive(Clone, Default, Debug, PartialEq)] @@ -17,6 +20,16 @@ where s_hat: NttVector, } +#[cfg(feature = "zeroize")] +impl

Zeroize for DecryptionKey

+where + P: PkeParams, +{ + fn zeroize(&mut self) { + self.s_hat.zeroize(); + } +} + impl

DecryptionKey

where P: PkeParams,