From 45458d3b267c48b8868b1d990b4279ce0c9f39e4 Mon Sep 17 00:00:00 2001 From: jacobkaufmann Date: Wed, 1 Nov 2023 16:15:08 -0600 Subject: [PATCH] fix: add deserialization error checks for group points --- src/bls.rs | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/bls.rs b/src/bls.rs index 18dd65a..28267b5 100644 --- a/src/bls.rs +++ b/src/bls.rs @@ -6,10 +6,10 @@ use std::{ use blst::{ blst_bendian_from_scalar, blst_fp, blst_fr, blst_fr_add, blst_fr_eucl_inverse, blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_mul, blst_fr_sub, blst_p1, blst_p1_add, - blst_p1_affine, blst_p1_deserialize, blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, - blst_p1_on_curve, blst_p2, blst_p2_affine, blst_p2_deserialize, blst_p2_from_affine, - blst_p2_in_g2, blst_p2_on_curve, blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian, - blst_scalar_from_fr, blst_scalar_from_uint64, blst_uint64_from_fr, + blst_p1_affine, blst_p1_affine_in_g1, blst_p1_deserialize, blst_p1_from_affine, blst_p1_mult, + blst_p2, blst_p2_affine, blst_p2_affine_in_g2, blst_p2_deserialize, blst_p2_from_affine, + blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian, blst_scalar_from_fr, + blst_scalar_from_uint64, blst_uint64_from_fr, BLST_ERROR, }; #[derive(Clone, Copy, Debug)] @@ -295,17 +295,19 @@ impl P1 { let mut affine = MaybeUninit::::uninit(); let mut out = MaybeUninit::::uninit(); unsafe { - blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()); - blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr()); - - // TODO: is one of the following checks more expensive (?) - if !blst_p1_in_g1(out.as_ptr()) { - return Err(ECGroupError::NotInGroup); + // NOTE: deserialize performs a curve check but not a subgroup check. if that changes, + // then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests. + match blst_p1_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) { + BLST_ERROR::BLST_SUCCESS => {} + BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding), + BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve), + other => unreachable!("{other:?}"), } - if !blst_p1_on_curve(out.as_ptr()) { - return Err(ECGroupError::NotOnCurve); + if !blst_p1_affine_in_g1(affine.as_ptr()) { + return Err(ECGroupError::NotInGroup); } + blst_p1_from_affine(out.as_mut_ptr(), affine.as_ptr()); Ok(Self { element: out.assume_init(), }) @@ -354,17 +356,19 @@ impl P2 { let mut affine = MaybeUninit::::uninit(); let mut out = MaybeUninit::::uninit(); unsafe { - blst_p2_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()); - blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr()); - - // TODO: is one of the following checks more expensive (?) - if !blst_p2_in_g2(out.as_ptr()) { - return Err(ECGroupError::NotInGroup); + // NOTE: deserialize performs a curve check but not a subgroup check. if that changes, + // then we should encounter `unreachable` for `BLST_POINT_NOT_IN_GROUP` in tests. + match blst_p2_deserialize(affine.as_mut_ptr(), bytes.as_ref().as_ptr()) { + BLST_ERROR::BLST_SUCCESS => {} + BLST_ERROR::BLST_BAD_ENCODING => return Err(ECGroupError::InvalidEncoding), + BLST_ERROR::BLST_POINT_NOT_ON_CURVE => return Err(ECGroupError::NotOnCurve), + other => unreachable!("{other:?}"), } - if !blst_p2_on_curve(out.as_ptr()) { - return Err(ECGroupError::NotOnCurve); + if !blst_p2_affine_in_g2(affine.as_ptr()) { + return Err(ECGroupError::NotInGroup); } + blst_p2_from_affine(out.as_mut_ptr(), affine.as_ptr()); Ok(Self { element: out.assume_init(), })