Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add proofs for portable compress module #631

Open
wants to merge 1 commit into
base: experiment-refined-ints
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,43 @@ let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16) =

let compress_message_coefficient (fe: u16) =
let (shifted: i16):i16 = Rust_primitives.mk_i16 1664 -! (cast (fe <: u16) <: i16) in
let _:Prims.unit = assert (v shifted == 1664 - v fe) in
let mask:i16 = shifted >>! Rust_primitives.mk_i32 15 in
let _:Prims.unit =
assert (v mask = v shifted / pow2 15);
assert (if v shifted < 0 then mask = ones else mask = zero)
in
let shifted_to_positive:i16 = mask ^. shifted in
let _:Prims.unit =
logxor_lemma shifted mask;
assert (v shifted < 0 ==> v shifted_to_positive = v (lognot shifted));
neg_equiv_lemma shifted;
assert (v (lognot shifted) = - (v shifted) - 1);
assert (v shifted >= 0 ==> v shifted_to_positive = v (mask `logxor` shifted));
assert (v shifted >= 0 ==> mask = zero);
assert (v shifted >= 0 ==> mask ^. shifted = shifted);
assert (v shifted >= 0 ==> v shifted_to_positive = v shifted);
assert (shifted_to_positive >=. mk_i16 0)
in
let shifted_positive_in_range:i16 = shifted_to_positive -! Rust_primitives.mk_i16 832 in
cast ((shifted_positive_in_range >>! Rust_primitives.mk_i32 15 <: i16) &. Rust_primitives.mk_i16 1
<:
i16)
<:
u8
let _:Prims.unit =
assert (1664 - v fe >= 0 ==> v shifted_positive_in_range == 832 - v fe);
assert (1664 - v fe < 0 ==> v shifted_positive_in_range == - 2497 + v fe)
in
let r0:i16 = shifted_positive_in_range >>! Rust_primitives.mk_i32 15 in
let (r1: i16):i16 = r0 &. Rust_primitives.mk_i16 1 in
let res:u8 = cast (r1 <: i16) <: u8 in
let _:Prims.unit =
assert (v r0 = v shifted_positive_in_range / pow2 15);
assert (if v shifted_positive_in_range < 0 then r0 = ones else r0 = zero);
logand_lemma (mk_i16 1) r0;
assert (if v shifted_positive_in_range < 0 then r1 = mk_i16 1 else r1 = mk_i16 0);
assert ((v fe >= 833 && v fe <= 2496) ==> r1 = mk_i16 1);
assert (v fe < 833 ==> r1 = mk_i16 0);
assert (v fe > 2496 ==> r1 = mk_i16 0);
assert (v res = v r1)
in
res

#push-options "--fuel 0 --ifuel 0 --z3rlimit 2000"

Expand Down Expand Up @@ -167,45 +196,88 @@ let compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)

#pop-options

#push-options "--z3rlimit 300 --ext context_pruning"

let decompress_ciphertext_coefficient
(v_COEFFICIENT_BITS: i32)
(v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
(a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
=
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
let _:Prims.unit =
assert_norm (pow2 1 == 2);
assert_norm (pow2 4 == 16);
assert_norm (pow2 5 == 32);
assert_norm (pow2 10 == 1024);
assert_norm (pow2 11 == 2048)
in
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
Rust_primitives.Hax.Folds.fold_range (Rust_primitives.mk_usize 0)
Libcrux_ml_kem.Vector.Traits.v_FIELD_ELEMENTS_IN_VECTOR
(fun v temp_1_ ->
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in
let _:usize = temp_1_ in
true)
v
(fun v i ->
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in
(fun a i ->
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in
let i:usize = i in
(v i < 16 ==>
(forall (j: nat).
(j >= v i /\ j < 16) ==>
v (Seq.index a.f_elements j) >= 0 /\
v (Seq.index a.f_elements j) < pow2 (v v_COEFFICIENT_BITS))) /\
(forall (j: nat).
j < v i ==>
v (Seq.index a.f_elements j) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS))
a
(fun a i ->
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in
let i:usize = i in
let _:Prims.unit =
assert (v (a.f_elements.[ i ] <: i16) < pow2 11);
assert (v (a.f_elements.[ i ] <: i16) == v (cast (a.f_elements.[ i ] <: i16) <: i32));
assert (v (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) ==
v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32));
assert (v ((cast (a.f_elements.[ i ] <: i16) <: i32) *!
(cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)) ==
v (cast (a.f_elements.[ i ] <: i16) <: i32) *
v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32))
in
let decompressed:i32 =
(cast (v.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *!
(cast (a.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *!
(cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)
in
let _:Prims.unit =
assert (v (decompressed <<! mk_i32 1) == v decompressed * 2);
assert (v (mk_i32 1 <<! v_COEFFICIENT_BITS) == pow2 (v v_COEFFICIENT_BITS));
assert (v ((decompressed <<! mk_i32 1) +! (mk_i32 1 <<! v_COEFFICIENT_BITS)) ==
v (decompressed <<! mk_i32 1) + v (mk_i32 1 <<! v_COEFFICIENT_BITS))
in
let decompressed:i32 =
(decompressed <<! Rust_primitives.mk_i32 1 <: i32) +!
(Rust_primitives.mk_i32 1 <<! v_COEFFICIENT_BITS <: i32)
in
let _:Prims.unit =
assert (v (v_COEFFICIENT_BITS +! mk_i32 1) == v v_COEFFICIENT_BITS + 1);
assert (v (decompressed >>! (v_COEFFICIENT_BITS +! mk_i32 1 <: i32)) ==
v decompressed / pow2 (v v_COEFFICIENT_BITS + 1))
in
let decompressed:i32 =
decompressed >>! (v_COEFFICIENT_BITS +! Rust_primitives.mk_i32 1 <: i32)
in
let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
let _:Prims.unit =
assert (v decompressed < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS);
assert (v (cast decompressed <: i16) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS)
in
let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector =
{
v with
a with
Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize v
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize a
.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements
i
(cast (decompressed <: i32) <: i16)
}
<:
Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
in
v)
a)
in
v
a

#pop-options
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ val compress_message_coefficient (fe: u16)
fun result ->
let result:u8 = result in
Hax_lib.implies ((Rust_primitives.mk_u16 833 <=. fe <: bool) &&
(fe <=. Rust_primitives.mk_u16 2596 <: bool))
(fe <=. Rust_primitives.mk_u16 2496 <: bool))
(fun temp_0_ ->
let _:Prims.unit = temp_0_ in
result =. Rust_primitives.mk_u8 1 <: bool) &&
Hax_lib.implies (~.((Rust_primitives.mk_u16 833 <=. fe <: bool) &&
(fe <=. Rust_primitives.mk_u16 2596 <: bool))
(fe <=. Rust_primitives.mk_u16 2496 <: bool))
<:
bool)
(fun temp_0_ ->
Expand Down Expand Up @@ -84,7 +84,18 @@ val compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)

val decompress_ciphertext_coefficient
(v_COEFFICIENT_BITS: i32)
(v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
(a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector)
: Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector
Prims.l_True
(fun _ -> Prims.l_True)
(requires
(v v_COEFFICIENT_BITS == 4 \/ v v_COEFFICIENT_BITS == 5 \/ v v_COEFFICIENT_BITS == 10 \/
v v_COEFFICIENT_BITS == 11) /\
(forall (i: nat).
i < 16 ==>
v (Seq.index a.f_elements i) >= 0 /\
v (Seq.index a.f_elements i) < pow2 (v v_COEFFICIENT_BITS)))
(ensures
fun result ->
let result:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = result in
forall (i: nat).
i < 16 ==>
v (Seq.index result.f_elements i) < Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS)
81 changes: 67 additions & 14 deletions libcrux-ml-kem/src/vector/portable/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::vector::FIELD_MODULUS;
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
#[cfg_attr(hax, hax_lib::requires(fe < (FIELD_MODULUS as u16)))]
#[cfg_attr(hax, hax_lib::ensures(|result|
hax_lib::implies(833 <= fe && fe <= 2596, || result == 1) &&
hax_lib::implies(!(833 <= fe && fe <= 2596), || result == 0)
hax_lib::implies(833 <= fe && fe <= 2496, || result == 1) &&
hax_lib::implies(!(833 <= fe && fe <= 2496), || result == 0)
))]
pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// The approach used here is inspired by:
Expand All @@ -35,6 +35,7 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// If 833 <= fe <= 2496,
// then -832 <= shifted <= 831
let shifted: i16 = 1664 - (fe as i16);
hax_lib::fstar!("assert (v $shifted == 1664 - v $fe)");

// If shifted < 0, then
// (shifted >> 15) ^ shifted = flip_bits(shifted) = -shifted - 1, and so
Expand All @@ -44,13 +45,37 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 {
// (shifted >> 15) ^ shifted = shifted, and so
// if 0 <= shifted <= 831 then 0 <= shifted_positive <= 831
let mask = shifted >> 15;
hax_lib::fstar!("assert (v $mask = v $shifted / pow2 15);
assert (if v $shifted < 0 then $mask = ones else $mask = zero)");
let shifted_to_positive = mask ^ shifted;
hax_lib::fstar!("logxor_lemma $shifted $mask;
assert (v $shifted < 0 ==> v $shifted_to_positive = v (lognot $shifted));
neg_equiv_lemma $shifted;
assert (v (lognot $shifted) = -(v $shifted) -1);
assert (v $shifted >= 0 ==> v $shifted_to_positive = v ($mask `logxor` $shifted));
assert (v $shifted >= 0 ==> $mask = zero);
assert (v $shifted >= 0 ==> $mask ^. $shifted = $shifted);
assert (v $shifted >= 0 ==> v $shifted_to_positive = v $shifted);
assert ($shifted_to_positive >=. mk_i16 0)");

let shifted_positive_in_range = shifted_to_positive - 832;
hax_lib::fstar!("assert (1664 - v $fe >= 0 ==> v $shifted_positive_in_range == 832 - v $fe);
assert (1664 - v $fe < 0 ==> v $shifted_positive_in_range == -2497 + v $fe)");

// If x <= 831, then x - 832 <= -1, and so x - 832 < 0, which means
// the most significant bit of shifted_positive_in_range will be 1.
((shifted_positive_in_range >> 15) & 1) as u8
let r0 = shifted_positive_in_range >> 15;
let r1: i16 = r0 & 1;
let res = r1 as u8;
hax_lib::fstar!("assert (v $r0 = v $shifted_positive_in_range / pow2 15);
assert (if v $shifted_positive_in_range < 0 then $r0 = ones else $r0 = zero);
logand_lemma (mk_i16 1) $r0;
assert (if v $shifted_positive_in_range < 0 then $r1 = mk_i16 1 else $r1 = mk_i16 0);
assert ((v $fe >= 833 && v $fe <= 2496) ==> $r1 = mk_i16 1);
assert (v $fe < 833 ==> $r1 = mk_i16 0);
assert (v $fe > 2496 ==> $r1 = mk_i16 0);
assert (v $res = v $r1)");
res
}

#[cfg_attr(hax,
Expand Down Expand Up @@ -147,23 +172,51 @@ pub(crate) fn compress<const COEFFICIENT_BITS: i32>(mut a: PortableVector) -> Po
}

#[inline(always)]
#[hax_lib::fstar::options("--z3rlimit 300 --ext context_pruning")]
#[hax_lib::requires(fstar!("(v $COEFFICIENT_BITS == 4 \\/
v $COEFFICIENT_BITS == 5 \\/
v $COEFFICIENT_BITS == 10 \\/
v $COEFFICIENT_BITS == 11) /\\
(forall (i:nat). i < 16 ==> v (Seq.index ${a}.f_elements i) >= 0 /\\
v (Seq.index ${a}.f_elements i) < pow2 (v $COEFFICIENT_BITS))"))]
#[hax_lib::ensures(|result| fstar!("forall (i:nat). i < 16 ==> v (Seq.index ${result}.f_elements i) < $FIELD_MODULUS"))]
pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
mut v: PortableVector,
mut a: PortableVector,
) -> PortableVector {
// debug_assert!(to_i16_array(v)
// .into_iter()
// .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS));
hax_lib::fstar!("assert_norm (pow2 1 == 2);
assert_norm (pow2 4 == 16);
assert_norm (pow2 5 == 32);
assert_norm (pow2 10 == 1024);
assert_norm (pow2 11 == 2048)");

for i in 0..FIELD_ELEMENTS_IN_VECTOR {
let mut decompressed = v.elements[i] as i32 * FIELD_MODULUS as i32;
hax_lib::loop_invariant!(|i: usize| { fstar!("(v $i < 16 ==> (forall (j:nat). (j >= v $i /\\ j < 16) ==>
v (Seq.index ${a}.f_elements j) >= 0 /\\ v (Seq.index ${a}.f_elements j) < pow2 (v $COEFFICIENT_BITS))) /\\
(forall (j:nat). j < v $i ==>
v (Seq.index ${a}.f_elements j) < v $FIELD_MODULUS)") });
hax_lib::fstar!("assert (v (${a}.f_elements.[ $i ] <: i16) < pow2 11);
assert (v (${a}.f_elements.[ $i ] <: i16) ==
v (cast (${a}.f_elements.[ $i ] <: i16) <: i32));
assert (v ($FIELD_MODULUS <: i16) ==
v (cast ($FIELD_MODULUS <: i16) <: i32));
assert (v ((cast (${a}.f_elements.[ $i ] <: i16) <: i32) *!
(cast ($FIELD_MODULUS <: i16) <: i32)) ==
v (cast (${a}.f_elements.[ $i ] <: i16) <: i32) *
v (cast ($FIELD_MODULUS <: i16) <: i32))");
let mut decompressed = a.elements[i] as i32 * FIELD_MODULUS as i32;
hax_lib::fstar!("assert (v ($decompressed <<! mk_i32 1) == v $decompressed * 2);
assert (v (mk_i32 1 <<! $COEFFICIENT_BITS) == pow2 (v $COEFFICIENT_BITS));
assert (v (($decompressed <<! mk_i32 1) +! (mk_i32 1 <<! $COEFFICIENT_BITS)) ==
v ($decompressed <<! mk_i32 1) + v (mk_i32 1 <<! $COEFFICIENT_BITS))");
decompressed = (decompressed << 1) + (1i32 << COEFFICIENT_BITS);
hax_lib::fstar!("assert (v ($COEFFICIENT_BITS +! mk_i32 1) == v $COEFFICIENT_BITS + 1);
assert (v ($decompressed >>! ($COEFFICIENT_BITS +! mk_i32 1 <: i32)) ==
v $decompressed / pow2 (v $COEFFICIENT_BITS + 1))");
decompressed = decompressed >> (COEFFICIENT_BITS + 1);
v.elements[i] = decompressed as i16;
hax_lib::fstar!("assert (v $decompressed < v $FIELD_MODULUS);
assert (v (cast $decompressed <: i16) < v $FIELD_MODULUS)");
a.elements[i] = decompressed as i16;
}

// debug_assert!(to_i16_array(v)
// .into_iter()
// .all(|coefficient| coefficient.abs() as u16 <= 1 << 12));

v
a
}