diff --git a/Cargo.toml b/Cargo.toml index 26444192..b614e4b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ version = "0.5" features = ["ecdsa"] [dev-dependencies] +serde_test = "1.0" bincode = "1.1" serde_json = "1.0" paste = "1.0.2" diff --git a/src/arithmetic/big_gmp.rs b/src/arithmetic/big_gmp.rs index 4526174f..ab2e8f87 100644 --- a/src/arithmetic/big_gmp.rs +++ b/src/arithmetic/big_gmp.rs @@ -21,7 +21,6 @@ use std::{fmt, ops, ptr}; use gmp::mpz::Mpz; use gmp::sign::Sign; use num_traits::{One, Zero}; -use serde::{Deserialize, Serialize}; use zeroize::Zeroize; use super::errors::*; @@ -35,8 +34,7 @@ type BN = Mpz; /// very limited API that allows easily switching between implementations. /// /// Set of traits implemented on BigInt remains the same regardless of underlying implementation. -#[derive(PartialOrd, PartialEq, Ord, Eq, Clone, Serialize, Deserialize)] -#[serde(transparent)] +#[derive(PartialOrd, PartialEq, Ord, Eq, Clone)] pub struct BigInt { gmp: Mpz, } diff --git a/src/arithmetic/big_native.rs b/src/arithmetic/big_native.rs index 8ccfd1d8..d83f8eb9 100644 --- a/src/arithmetic/big_native.rs +++ b/src/arithmetic/big_native.rs @@ -2,7 +2,6 @@ use std::convert::{TryFrom, TryInto}; use std::{fmt, ops}; use num_traits::Signed; -use serde::{Deserialize, Serialize}; use super::errors::*; use super::traits::*; @@ -18,8 +17,7 @@ mod primes; /// very limited API that allows easily switching between implementations. /// /// Set of traits implemented on BigInt remains the same regardless of underlying implementation. -#[derive(PartialOrd, PartialEq, Ord, Eq, Clone, Serialize, Deserialize)] -#[serde(transparent)] +#[derive(PartialOrd, PartialEq, Ord, Eq, Clone)] pub struct BigInt { num: BN, } diff --git a/src/arithmetic/mod.rs b/src/arithmetic/mod.rs index dbda76a7..7badd37d 100644 --- a/src/arithmetic/mod.rs +++ b/src/arithmetic/mod.rs @@ -17,6 +17,7 @@ mod errors; mod macros; mod samplable; +mod serde_support; pub mod traits; #[cfg(not(any(feature = "rust-gmp-kzen", feature = "num-bigint")))] @@ -31,6 +32,7 @@ pub use big_gmp::BigInt; #[cfg(feature = "num-bigint")] mod big_native; + #[cfg(feature = "num-bigint")] pub use big_native::BigInt; @@ -45,6 +47,16 @@ mod test { use super::*; + #[test] + fn serde() { + use serde_test::{assert_tokens, Token::*}; + for bigint in [BigInt::zero(), BigInt::sample(1024)] { + let bytes = bigint.to_bytes(); + let tokens = vec![Bytes(bytes.leak())]; + assert_tokens(&bigint, &tokens) + } + } + #[test] fn serializing_to_hex() { let n = BigInt::from(1_000_000_u32); diff --git a/src/arithmetic/serde_support.rs b/src/arithmetic/serde_support.rs new file mode 100644 index 00000000..fc6040aa --- /dev/null +++ b/src/arithmetic/serde_support.rs @@ -0,0 +1,43 @@ +use std::fmt; + +use serde::de::Visitor; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use super::traits::Converter; +use super::BigInt; + +impl Serialize for BigInt { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let bytes = self.to_bytes(); + serializer.serialize_bytes(&bytes) + } +} + +impl<'de> Deserialize<'de> for BigInt { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct BigintVisitor; + + impl<'de> Visitor<'de> for BigintVisitor { + type Value = BigInt; + + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "bigint") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(BigInt::from_bytes(v)) + } + } + + deserializer.deserialize_bytes(BigintVisitor) + } +} diff --git a/src/elliptic/curves/secp256_k1.rs b/src/elliptic/curves/secp256_k1.rs index 7e813db2..e1c8e267 100644 --- a/src/elliptic/curves/secp256_k1.rs +++ b/src/elliptic/curves/secp256_k1.rs @@ -32,8 +32,7 @@ use secp256k1::constants::{ CURVE_ORDER, GENERATOR_X, GENERATOR_Y, SECRET_KEY_SIZE, UNCOMPRESSED_PUBLIC_KEY_SIZE, }; use secp256k1::{PublicKey, Secp256k1, SecretKey, VerifyOnly}; -use serde::de::{self, Error, MapAccess, SeqAccess, Visitor}; -use serde::ser::SerializeStruct; +use serde::de::{self, Visitor}; use serde::ser::{Serialize, Serializer}; use serde::{Deserialize, Deserializer}; use std::fmt; @@ -543,10 +542,7 @@ impl Serialize for Secp256k1Point { where S: Serializer, { - let mut state = serializer.serialize_struct("Secp256k1Point", 2)?; - state.serialize_field("x", &self.x_coor().unwrap().to_hex())?; - state.serialize_field("y", &self.y_coor().unwrap().to_hex())?; - state.end() + serializer.serialize_str(&self.bytes_compressed_to_big_int().to_hex()) } } @@ -555,8 +551,7 @@ impl<'de> Deserialize<'de> for Secp256k1Point { where D: Deserializer<'de>, { - let fields = &["x", "y"]; - deserializer.deserialize_struct("Secp256k1Point", fields, Secp256k1PointVisitor) + deserializer.deserialize_str(Secp256k1PointVisitor) } } @@ -569,42 +564,16 @@ impl<'de> Visitor<'de> for Secp256k1PointVisitor { formatter.write_str("Secp256k1Point") } - fn visit_seq(self, mut seq: V) -> Result + fn visit_str(self, p: &str) -> Result where - V: SeqAccess<'de>, + E: serde::de::Error, { - let x = seq - .next_element()? - .ok_or_else(|| V::Error::invalid_length(0, &"a single element"))?; - let y = seq - .next_element()? - .ok_or_else(|| V::Error::invalid_length(0, &"a single element"))?; - - let bx = BigInt::from_hex(x).map_err(V::Error::custom)?; - let by = BigInt::from_hex(y).map_err(V::Error::custom)?; - - Ok(Secp256k1Point::from_coor(&bx, &by)) - } - - fn visit_map>(self, mut map: E) -> Result { - let mut x = String::new(); - let mut y = String::new(); - - while let Some(ref key) = map.next_key::()? { - let v = map.next_value::()?; - if key == "x" { - x = v - } else if key == "y" { - y = v - } else { - return Err(E::Error::unknown_field(key, &["x", "y"])); - } + let bp = BigInt::from_hex(p).map_err(E::custom)?; + let bp = bp.to_bytes(); + if bp.len() < 33 { + return Err(E::invalid_length(bp.len(), &"33 bytes")); } - - let bx = BigInt::from_hex(&x).map_err(E::Error::custom)?; - let by = BigInt::from_hex(&y).map_err(E::Error::custom)?; - - Ok(Secp256k1Point::from_coor(&bx, &by)) + Secp256k1Point::from_bytes(&bp[1..33]).map_err(|_e| E::custom("invalid point")) } } @@ -673,13 +642,12 @@ mod tests { fn serialize_pk() { let pk = Secp256k1Point::generator(); let x = pk.x_coor().unwrap(); - let y = pk.y_coor().unwrap(); let s = serde_json::to_string(&pk).expect("Failed in serialization"); - let expected = format!("{{\"x\":\"{}\",\"y\":\"{}\"}}", x.to_hex(), y.to_hex()); + let expected = format!("\"{}{}\"", 2, x.to_hex()); assert_eq!(s, expected); - let des_pk: Secp256k1Point = serde_json::from_str(&s).expect("Failed in serialization"); + let des_pk: Secp256k1Point = serde_json::from_str(&s).expect("Failed in deserialization"); assert_eq!(des_pk.ge, pk.ge); }