From 5a0a19c1a4d704e805049f97153e53e426d57bfc Mon Sep 17 00:00:00 2001 From: Martin Grenouilloux Date: Wed, 28 Feb 2024 17:36:55 +0100 Subject: [PATCH] mitigate kyberslash with official patching method from pq-crystals/kyber --- src/reference/poly.rs | 28 ++++++++++++++++++++++------ src/reference/polyvec.rs | 20 ++++++++++++++++---- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/reference/poly.rs b/src/reference/poly.rs index 236eb3d..ec9bd19 100644 --- a/src/reference/poly.rs +++ b/src/reference/poly.rs @@ -32,6 +32,7 @@ pub fn poly_compress(r: &mut [u8], a: Poly) { let mut t = [0u8; 8]; let mut k = 0usize; let mut u: i16; + let mut d0: u32; match KYBER_POLYCOMPRESSEDBYTES { 128 => { @@ -40,7 +41,12 @@ pub fn poly_compress(r: &mut [u8], a: Poly) { // map to positive standard representatives u = a.coeffs[8 * i + j]; u += (u >> 15) & KYBER_Q as i16; - t[j] = (((((u as u16) << 4) + KYBER_Q as u16 / 2) / KYBER_Q as u16) & 15) as u8; + /* t[j] = (((((u as u16) << 4) + KYBER_Q as u16 / 2) / KYBER_Q as u16) & 15) as u8; */ + d0 = ((u as u16) << 4) as u32; + d0 = d0.wrapping_add(1665); + d0 = d0.wrapping_mul(80635); + d0 >>= 28; + t[j] = (d0 & 0xf) as u8; } r[k] = t[0] | (t[1] << 4); r[k + 1] = t[2] | (t[3] << 4); @@ -55,7 +61,12 @@ pub fn poly_compress(r: &mut [u8], a: Poly) { // map to positive standard representatives u = a.coeffs[8 * i + j]; u += (u >> 15) & KYBER_Q as i16; - t[j] = (((((u as u32) << 5) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 31) as u8; + /* t[j] = (((((u as u32) << 5) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 31) as u8; */ + d0 = ((u as u32) << 5) as u32; + d0 = d0.wrapping_add(1664); + d0 = d0.wrapping_mul(40318); + d0 >>= 27; + t[j] = (d0 & 0x1f) as u8; } r[k] = t[0] | (t[1] << 5); r[k + 1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7); @@ -300,14 +311,19 @@ pub fn poly_frommsg(r: &mut Poly, msg: &[u8]) { /// Arguments: - [u8] msg: output message /// - const poly *a: input polynomial pub fn poly_tomsg(msg: &mut [u8], a: Poly) { - let mut t; + let mut t: u32; for i in 0..KYBER_N / 8 { msg[i] = 0; for j in 0..8 { - t = a.coeffs[8 * i + j]; - t += (t >> 15) & KYBER_Q as i16; - t = (((t << 1) + KYBER_Q as i16 / 2) / KYBER_Q as i16) & 1; + t = a.coeffs[8 * i + j] as u32; + // t += (t >> 15) & KYBER_Q as i16; + // t = (((t << 1) + KYBER_Q as i16 / 2) / KYBER_Q as i16) & 1; + t <<= 1; + t = t.wrapping_add(1665); + t = t.wrapping_mul(80635); + t >>= 28; + t &= 1; msg[i] |= (t << j) as u8; } } diff --git a/src/reference/polyvec.rs b/src/reference/polyvec.rs index 4b9e69a..324c63d 100644 --- a/src/reference/polyvec.rs +++ b/src/reference/polyvec.rs @@ -25,6 +25,7 @@ impl Polyvec { pub fn polyvec_compress(r: &mut [u8], a: Polyvec) { #[cfg(feature = "kyber1024")] { + let mut d0: u64; let mut t = [0u16; 8]; let mut idx = 0usize; for i in 0..KYBER_K { @@ -32,8 +33,13 @@ pub fn polyvec_compress(r: &mut [u8], a: Polyvec) { for k in 0..8 { t[k] = a.vec[i].coeffs[8 * j + k] as u16; t[k] = t[k].wrapping_add((((t[k] as i16) >> 15) & KYBER_Q as i16) as u16); - t[k] = (((((t[k] as u32) << 11) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 0x7ff) - as u16; + /*t[k] = (((((t[k] as u32) << 11) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 0x7ff) as u16; */ + d0 = t[k] as u64; + d0 <<= 11; + d0 = d0.wrapping_add(1664); + d0 = d0.wrapping_mul(645084); + d0 >>= 31; + t[k] = (d0 & 0x7ff) as u16; } r[idx + 0] = (t[0] >> 0) as u8; r[idx + 1] = ((t[0] >> 8) | (t[1] << 3)) as u8; @@ -53,6 +59,7 @@ pub fn polyvec_compress(r: &mut [u8], a: Polyvec) { #[cfg(not(feature = "kyber1024"))] { + let mut d0: u64; let mut t = [0u16; 4]; let mut idx = 0usize; for i in 0..KYBER_K { @@ -60,8 +67,13 @@ pub fn polyvec_compress(r: &mut [u8], a: Polyvec) { for k in 0..4 { t[k] = a.vec[i].coeffs[4 * j + k] as u16; t[k] = t[k].wrapping_add((((t[k] as i16) >> 15) & KYBER_Q as i16) as u16); - t[k] = (((((t[k] as u32) << 10) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 0x3ff) - as u16; + /* t[k] = (((((t[k] as u32) << 10) + KYBER_Q as u32 / 2) / KYBER_Q as u32) & 0x3ff) as u16; */ + d0 = t[k] as u64; + d0 <<= 10; + d0 = d0.wrapping_add(1665); + d0 = d0.wrapping_mul(1290167); + d0 >>= 32; + t[k] = (d0 & 0x3ff) as u16; } r[idx + 0] = (t[0] >> 0) as u8; r[idx + 1] = ((t[0] >> 8) | (t[1] << 2)) as u8;