Skip to content

Commit

Permalink
chore: merge Encode/EncodeFields and Decode/DecodeFields
Browse files Browse the repository at this point in the history
  • Loading branch information
cfcosta committed Oct 30, 2024
1 parent 9c98e94 commit bf53baa
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 135 deletions.
1 change: 1 addition & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
cargo-nextest
cargo-pgo
cargo-watch
cargo-machete
];
};
}
Expand Down
16 changes: 6 additions & 10 deletions src/crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ pub use native::{NativePoint, NativeScalar};
use crate::{protocol::*, Error};

pub trait BlindDiffieHellmanKeyExchange {
fn hash_to_curve(&self, data: impl EncodeFields) -> Result<Hash, Error>;
fn blind(&self, data: impl EncodeFields, r: Hash) -> Result<BlindedValue, Error>;
fn hash_to_curve(&self, data: impl Encode) -> Result<Hash, Error>;
fn blind(&self, data: impl Encode, r: Hash) -> Result<BlindedValue, Error>;
fn unblind(
&self,
public_key: PublicKey,
Expand All @@ -21,12 +21,7 @@ pub trait BlindDiffieHellmanKeyExchange {
sk: SecretKey,
blinded_message: BlindedValue,
) -> Result<BlindSignature, Error>;
fn verify(
&self,
pk: PublicKey,
data: impl EncodeFields,
signature: Signature,
) -> Result<bool, Error>;
fn verify(&self, pk: PublicKey, data: impl Encode, signature: Signature) -> Result<bool, Error>;
}

#[inline]
Expand All @@ -51,10 +46,11 @@ mod tests {
fn test_htc<T: BlindDiffieHellmanKeyExchange + UnwindSafe>(note: Note, bdhke: T) -> Result {
let fields = note.as_fields();
let native = bdhke.hash_to_curve(note.clone())?.into();
let size = note.field_len();

let proof = unwind_panic(move || {
let mut builder = circuit_builder();
let inputs = builder.add_virtual_targets(Note::FIELD_SIZE);
let inputs = builder.add_virtual_targets(size);
let expected = builder.add_virtual_hash();
let result = hash_to_curve(&mut builder, &inputs);

Expand All @@ -73,7 +69,7 @@ mod tests {
})?;

prop_assert_eq!(
Hash::from_fields(&proof.public_inputs[Note::FIELD_SIZE..])?,
Hash::from_fields(&proof.public_inputs[size..])?,
bdhke.hash_to_curve(note.clone())?
);

Expand Down
11 changes: 3 additions & 8 deletions src/crypto/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ pub struct NativeBdhke;

impl BlindDiffieHellmanKeyExchange for NativeBdhke {
#[inline]
fn hash_to_curve(&self, data: impl EncodeFields) -> Result<Hash, Error> {
fn hash_to_curve(&self, data: impl Encode) -> Result<Hash, Error> {
let data: NativeScalar = data.hash().try_into()?;
Ok((data * G).into())
}

#[inline]
fn blind(&self, data: impl EncodeFields, r: Hash) -> Result<BlindedValue, Error> {
fn blind(&self, data: impl Encode, r: Hash) -> Result<BlindedValue, Error> {
let y: NativePoint = NativeScalar::try_from(self.hash_to_curve(data)?)? * G;
let r_scalar: NativeScalar = r.try_into()?;

Expand Down Expand Up @@ -48,12 +48,7 @@ impl BlindDiffieHellmanKeyExchange for NativeBdhke {
Ok(c_prime.into())
}

fn verify(
&self,
pk: PublicKey,
data: impl EncodeFields,
signature: Signature,
) -> Result<bool, Error> {
fn verify(&self, pk: PublicKey, data: impl Encode, signature: Signature) -> Result<bool, Error> {
let y: NativeScalar = self.hash_to_curve(data)?.try_into()?;
let a_point: NativePoint = pk.try_into()?;
let c: NativePoint = signature.try_into()?;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::panic::UnwindSafe;

pub use self::{
error::Error,
protocol::{Decode, DecodeFields, Encode, EncodeFields},
protocol::{Decode, Encode},
};

#[inline]
Expand Down
39 changes: 6 additions & 33 deletions src/protocol/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rand::{CryptoRng, Rng};
use serde::{Deserialize, Serialize};
use test_strategy::Arbitrary;

use crate::{protocol::circuit::*, Decode, DecodeFields, Encode, EncodeFields, Error};
use crate::{protocol::circuit::*, Decode, Encode, Error};

pub type Hash = Bytes<32>;
pub type BlindSignature = Bytes<32>;
Expand Down Expand Up @@ -196,13 +196,6 @@ impl From<Bytes<32>> for HashOut<F> {
}

impl<const N: usize> Encode for Bytes<N> {
#[inline]
fn as_bytes(&self) -> Vec<u8> {
self.0.to_vec()
}
}

impl<const N: usize> EncodeFields for Bytes<N> {
#[inline]
fn as_fields(&self) -> Vec<F> {
self.0
Expand All @@ -213,22 +206,6 @@ impl<const N: usize> EncodeFields for Bytes<N> {
}

impl<const N: usize> Decode for Bytes<N> {
#[inline]
fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() != N {
return Err(Error::DecodeError(format!(
"Invalid length for Bytes: expected N, got {}",
bytes.len()
)));
}

let mut array = [0u8; N];
array.copy_from_slice(bytes);
Ok(Self(array))
}
}

impl<const N: usize> DecodeFields for Bytes<N> {
#[inline]
fn from_fields(fields: &[F]) -> Result<Self, Error> {
if fields.len() != N / 8 {
Expand Down Expand Up @@ -258,23 +235,19 @@ mod tests {
use test_strategy::proptest;

use super::Bytes;
use crate::{test_encode_bytes, test_encode_fields};
use crate::test_encode_decode;

pub type Bytes8 = Bytes<8>;
test_encode_bytes!(Bytes8);
test_encode_fields!(Bytes8);
test_encode_decode!(Bytes8);

pub type Bytes16 = Bytes<16>;
test_encode_bytes!(Bytes16);
test_encode_fields!(Bytes16);
test_encode_decode!(Bytes16);

pub type Bytes32 = Bytes<32>;
test_encode_bytes!(Bytes32);
test_encode_fields!(Bytes32);
test_encode_decode!(Bytes32);

pub type Bytes64 = Bytes<64>;
test_encode_bytes!(Bytes64);
test_encode_fields!(Bytes64);
test_encode_decode!(Bytes64);

fn crypto_rng() -> impl Strategy<Value = ChaCha20Rng> {
any::<[u8; 32]>().prop_map(ChaCha20Rng::from_seed)
Expand Down
7 changes: 4 additions & 3 deletions src/protocol/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ pub fn magic_prefix() -> [F; 2] {
#[inline]
pub fn seal_note(builder: &mut CircuitBuilder) -> (HashOutTarget, Vec<Target>) {
let zero = builder.zero();
let note_size = Note::default().field_len();

// Private inputs
let mut targets = builder.add_virtual_targets(Note::FIELD_SIZE);
let mut targets = builder.add_virtual_targets(note_size);
let (amount, rest) = targets.split_at_mut(1);
let (asset_id, rest) = rest.split_at_mut(4);
let (asset_name, nonce) = rest.split_at_mut(4);
Expand Down Expand Up @@ -96,9 +97,9 @@ fn targets_are_zero(builder: &mut CircuitBuilder, targets: &[Target]) -> BoolTar
target
}

pub trait Sealable: EncodeFields + RefUnwindSafe {
pub trait Sealable: Encode + RefUnwindSafe {
type Circuit;
type Payload: EncodeFields;
type Payload: Encode;

fn circuit() -> Self::Circuit;
fn circuit_data() -> CircuitData;
Expand Down
72 changes: 60 additions & 12 deletions src/protocol/codec.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use plonky2::{hash::poseidon::PoseidonHash, plonk::config::Hasher};
use plonky2::{field::types::PrimeField64, hash::poseidon::PoseidonHash, plonk::config::Hasher};

use super::circuit::{Field, F};
use crate::{protocol::Hash, Error};
Expand All @@ -7,16 +7,12 @@ pub const MAGIC_PREFIX: [u8; 16] = *b"mugraph.v1.ecash";
pub const MAGIC_PREFIX_FIELDS: [u64; 2] = [3344046287156114797, 7526466481793413494];

pub trait Encode {
fn as_bytes(&self) -> Vec<u8>;
fn as_fields(&self) -> Vec<F>;

#[inline]
fn as_bytes_with_prefix(&self) -> Vec<u8> {
[MAGIC_PREFIX.to_vec(), self.as_bytes()].concat()
fn field_len(&self) -> usize {
self.as_fields().len()
}
}

pub trait EncodeFields {
fn as_fields(&self) -> Vec<F>;

#[inline]
fn as_fields_with_prefix(&self) -> Vec<F> {
Expand All @@ -33,21 +29,73 @@ pub trait EncodeFields {

#[inline]
fn hash(&self) -> Hash {
PoseidonHash::hash_no_pad(&self.as_fields()).into()
}

#[inline]
fn hash_with_prefix(&self) -> Hash {
PoseidonHash::hash_no_pad(&self.as_fields_with_prefix()).into()
}

#[inline]
fn as_bytes(&self) -> Vec<u8> {
fields_to_bytes(&self.as_fields())
}

#[inline]
fn byte_len(&self) -> usize {
self.as_bytes().len()
}

#[inline]
fn as_bytes_with_prefix(&self) -> Vec<u8> {
fields_to_bytes(&self.as_fields_with_prefix())
}

#[inline]
fn hash_bytes(&self) -> Hash {
PoseidonHash::hash_no_pad(&bytes_to_field(&self.as_bytes())).into()
}

#[inline]
fn hash_bytes_with_prefix(&self) -> Hash {
PoseidonHash::hash_no_pad(&bytes_to_field(&self.as_bytes_with_prefix())).into()
}
}

impl<T: EncodeFields> EncodeFields for [T] {
impl<T: Encode> Encode for [T] {
#[inline]
fn as_fields(&self) -> Vec<F> {
self.iter().flat_map(|x| x.as_fields()).collect()
}
}

pub trait Decode: Sized {
fn from_bytes(bytes: &[u8]) -> Result<Self, Error>;
fn from_fields(bytes: &[F]) -> Result<Self, Error>;

#[inline]
fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
Self::from_fields(&bytes_to_field(bytes))
}
}

pub trait DecodeFields: Sized {
fn from_fields(bytes: &[F]) -> Result<Self, Error>;
fn bytes_to_field(val: &[u8]) -> Vec<F> {
val.chunks(32)
.map(|chunk| {
let mut padded = [0u8; 32];
padded[..chunk.len()].copy_from_slice(chunk);
let value = u64::from_le_bytes(padded[..8].try_into().unwrap());
F::from_canonical_u64(value)
})
.collect()
}

fn fields_to_bytes(fields: &[F]) -> Vec<u8> {
fields
.iter()
.flat_map(|&f| {
let bytes = f.to_canonical_u64().to_le_bytes();
bytes.to_vec()
})
.collect()
}
4 changes: 2 additions & 2 deletions src/protocol/message/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct Payload {
outputs: Vec<BlindedValue>,
}

impl EncodeFields for Payload {
impl Encode for Payload {
#[inline]
fn as_fields(&self) -> Vec<F> {
self.outputs.iter().flat_map(|x| x.as_fields()).collect()
Expand Down Expand Up @@ -50,7 +50,7 @@ impl<const I: usize, const O: usize> Append<I, O> {
}
}

impl<const I: usize, const O: usize> EncodeFields for Append<I, O> {
impl<const I: usize, const O: usize> Encode for Append<I, O> {
#[inline]
fn as_fields(&self) -> Vec<F> {
let mut fields = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion src/protocol/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Payload {
pub outputs: Vec<BlindedValue>,
}

impl EncodeFields for Payload {
impl Encode for Payload {
#[inline]
fn as_fields(&self) -> Vec<F> {
self.inputs
Expand Down
Loading

0 comments on commit bf53baa

Please sign in to comment.