diff --git a/Cargo.lock b/Cargo.lock index 5c1b76f..0a6ac0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,7 +1,5 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 - [[package]] name = "addr2line" version = "0.14.0" diff --git a/crypto/zkp/utils/src/lib.rs b/crypto/zkp/utils/src/lib.rs index a619190..e65e327 100644 --- a/crypto/zkp/utils/src/lib.rs +++ b/crypto/zkp/utils/src/lib.rs @@ -394,22 +394,60 @@ pub fn point_to_slice(point: &RistrettoPoint) -> [u8; 32] { } /// Converts a vector to RistrettoPoint. +// pub fn bytes_to_point(point: &[u8]) -> Result { +// if point.len() != RISTRETTO_POINT_SIZE_IN_BYTES { +// wedpr_println!("bytes_to_point decode failed"); +// return Err(WedprError::FormatError); +// } +// let point_value = match CompressedRistretto::from_slice(&point).decompress() +// { +// Some(v) => v, +// None => { +// wedpr_println!( +// "bytes_to_point decompress CompressedRistretto failed" +// ); +// return Err(WedprError::FormatError); +// }, +// }; +// Ok(point_value) +// } + +/// Bytes to point with padding pub fn bytes_to_point(point: &[u8]) -> Result { - if point.len() != RISTRETTO_POINT_SIZE_IN_BYTES { - wedpr_println!("bytes_to_point decode failed"); + if point.len() > RISTRETTO_POINT_SIZE_IN_BYTES { + wedpr_println!("bytes_to_point decode failed, point: {:?}", point); return Err(WedprError::FormatError); } - let point_value = match CompressedRistretto::from_slice(&point).decompress() - { - Some(v) => v, - None => { - wedpr_println!( - "bytes_to_point decompress CompressedRistretto failed" + if point.len() < RISTRETTO_POINT_SIZE_IN_BYTES { + let padding = vec![0; RISTRETTO_POINT_SIZE_IN_BYTES - point.len()]; + let mut padded_point = padding; + padded_point.extend_from_slice(point); + let point_value = match CompressedRistretto::from_slice(&padded_point).decompress() + { + Some(v) => v, + None => { + wedpr_println!( + "bytes_to_point decompress CompressedRistretto failed, padded_point:{:?}, point: {:?}", padded_point, point ); - return Err(WedprError::FormatError); - }, - }; - Ok(point_value) + return Err(WedprError::FormatError); + }, + }; + Ok(point_value) + } + else { + let point_value = match CompressedRistretto::from_slice(&point).decompress() + { + Some(v) => v, + None => { + wedpr_println!( + "bytes_to_point decompress CompressedRistretto failed, point:{:?}", point + ); + return Err(WedprError::FormatError); + }, + }; + Ok(point_value) + } + } /// Gets a random u32 integer. @@ -418,3 +456,65 @@ pub fn get_random_u32() -> u32 { let blinding: u32 = rng.gen(); blinding } + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_point_length() + { + let count = 10000; + println!("run {} times", count); + for i in 0..count { + let random_scalar = get_random_scalar(); + let random_point1 = *BASEPOINT_G1 * random_scalar; + let new_point_vec1 = point_to_bytes(&random_point1); + assert_eq!(new_point_vec1.len(), 32); + } + } + + #[test] + fn test_point_padding() + { + let re_slice = [0, 219, 114, 57, 172, 179, 141, 252, 215, 70, 84, 226, 211, 59, 127, 202, 166, 66, 113, 48, 223, 179, 56, 123, 205, 134, 62, 12, 175, 172, 250, 108]; + let re_vector = re_slice.to_vec(); + let re_point = bytes_to_point( &re_vector).unwrap(); + let re_re_vector = point_to_bytes(&re_point); + assert_eq!(re_re_vector, re_vector); + + let re_slice_31 = [219, 114, 57, 172, 179, 141, 252, 215, 70, 84, 226, 211, 59, 127, 202, 166, 66, 113, 48, 223, 179, 56, 123, 205, 134, 62, 12, 175, 172, 250, 108]; + let re_vector_31 = re_slice_31.to_vec(); + let re_point_with_padding = bytes_to_point( &re_vector_31).unwrap(); + let re_point_vector = point_to_bytes(&re_point_with_padding); + assert_eq!(re_vector, re_point_vector); + + } + + #[test] + fn test_point_random_padding() + { + let count = 10000; + let mut times = 0; + println!("run {} times", count); + + for i in 0..count { + let random_scalar = get_random_scalar(); + let random_point1 = *BASEPOINT_G1 * random_scalar; + let new_point_vec1 = point_to_bytes(&random_point1); + let mut mut_new_point_vec1 = new_point_vec1.clone(); + if mut_new_point_vec1[0] == 0u8 { + mut_new_point_vec1.remove(0); + times = times + 1; + } + let re_point_with_padding = bytes_to_point( &mut_new_point_vec1).unwrap(); + let re_point_vector = point_to_bytes(&re_point_with_padding); + assert_eq!(new_point_vec1, re_point_vector); + } + println!("padding {} times", times); + } + + +} +