diff --git a/dkg-gadget/src/async_protocols/keygen/handler.rs b/dkg-gadget/src/async_protocols/ecdsa/keygen/handler.rs similarity index 100% rename from dkg-gadget/src/async_protocols/keygen/handler.rs rename to dkg-gadget/src/async_protocols/ecdsa/keygen/handler.rs diff --git a/dkg-gadget/src/async_protocols/keygen/mod.rs b/dkg-gadget/src/async_protocols/ecdsa/keygen/mod.rs similarity index 100% rename from dkg-gadget/src/async_protocols/keygen/mod.rs rename to dkg-gadget/src/async_protocols/ecdsa/keygen/mod.rs diff --git a/dkg-gadget/src/async_protocols/keygen/state_machine.rs b/dkg-gadget/src/async_protocols/ecdsa/keygen/state_machine.rs similarity index 100% rename from dkg-gadget/src/async_protocols/keygen/state_machine.rs rename to dkg-gadget/src/async_protocols/ecdsa/keygen/state_machine.rs diff --git a/dkg-gadget/src/async_protocols/ecdsa/mod.rs b/dkg-gadget/src/async_protocols/ecdsa/mod.rs new file mode 100644 index 000000000..e51884adf --- /dev/null +++ b/dkg-gadget/src/async_protocols/ecdsa/mod.rs @@ -0,0 +1,2 @@ +pub mod keygen; +pub mod sign; diff --git a/dkg-gadget/src/async_protocols/sign/handler.rs b/dkg-gadget/src/async_protocols/ecdsa/sign/handler.rs similarity index 100% rename from dkg-gadget/src/async_protocols/sign/handler.rs rename to dkg-gadget/src/async_protocols/ecdsa/sign/handler.rs diff --git a/dkg-gadget/src/async_protocols/sign/mod.rs b/dkg-gadget/src/async_protocols/ecdsa/sign/mod.rs similarity index 100% rename from dkg-gadget/src/async_protocols/sign/mod.rs rename to dkg-gadget/src/async_protocols/ecdsa/sign/mod.rs diff --git a/dkg-gadget/src/async_protocols/sign/state_machine.rs b/dkg-gadget/src/async_protocols/ecdsa/sign/state_machine.rs similarity index 100% rename from dkg-gadget/src/async_protocols/sign/state_machine.rs rename to dkg-gadget/src/async_protocols/ecdsa/sign/state_machine.rs diff --git a/dkg-gadget/src/async_protocols/frost/keygen/mod.rs b/dkg-gadget/src/async_protocols/frost/keygen/mod.rs new file mode 100644 index 000000000..c4ab629f1 --- /dev/null +++ b/dkg-gadget/src/async_protocols/frost/keygen/mod.rs @@ -0,0 +1,98 @@ +use std::collections::HashSet; +use sc_client_api::Backend; +use sp_core::hashing::sha2_256; +use sp_runtime::traits::Block; +use tokio::sync::mpsc::UnboundedReceiver; +use dkg_primitives::types::{DKGError, DKGMessage, NetworkMsgPayload, SignedDKGMessage}; +use dkg_runtime_primitives::crypto::AuthorityId; +use dkg_runtime_primitives::gossip_messages::DKGKeygenMessage; +use crate::async_protocols::blockchain_interface::BlockchainInterface; +use crate::async_protocols::remote::AsyncProtocolRemote; +use crate::Client; +use crate::dkg_modules::wt_frost::{FrostMessage, NetInterface}; +use crate::gossip_engine::GossipEngineIface; +use crate::worker::DKGWorker; + +pub struct FrostKeygen + where + B: Block, + BE: Backend, + C: Client, + GE: GossipEngineIface, { + pub dkg_worker: DKGWorker, + pub remote: AsyncProtocolRemote, + pub message_receiver: UnboundedReceiver>, + pub authority_id: AuthorityId, + pub keygen_protocol_hash: [u8; 32], + pub received_messages: HashSet<[u8;32]> + pub engine: BI +} + +impl FrostKeygen { + pub fn new(dkg_worker: DKGWorker, engine: BI, remote: AsyncProtocolRemote, authority_id: AuthorityId, retry_id: usize) -> Self { + let message_receiver = remote + .rx_keygen_signing + .lock() + .take() + .expect("rx_keygen_signing already taken"); + + let mut data = retry_id.to_be_bytes().to_vec(); + data.extend_from_slice(&remote.session_id.to_be_bytes()); + + let keygen_protocol_hash = sha2_256(&data); + let received_messages = HashSet::new(); + + Self { dkg_worker, engine, remote, message_receiver, authority_id, keygen_protocol_hash, received_messages } + } +} + +impl NetInterface for FrostKeygen + where + B: Block, + BE: Backend, + C: Client, + GE: GossipEngineIface { + type Error = DKGError; + + async fn next_message(&mut self) -> Result, Self::Error> { + loop { + let message = self.message_receiver.recv().await?; + // When we receive a message, it is filtered through the Job Manager, and as such + // we have these guarantees: + // * The SSID is correct, the block ID and session ID are acceptable, and the task hash is correct + // We do not need to check these things here, but we do need to check the signature + let message = self.engine.verify_signature_against_authorities(message).await?; + let message_bin = message.payload.payload(); + let message_hash = sha2_256(message_bin); + + if !self.received_messages.insert(message_hash) { + self.dkg_worker.logger.info("Received duplicate FROST keygen message, ignoring"); + continue; + } + + // Check to make sure we haven't already received the message + let deserialized = bincode2::deserialize::(message_bin) + .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; + return Ok(Some(deserialized)) + } + } + + async fn send_message(&mut self, msg: FrostMessage) -> Result<(), Self::Error> { + let keygen_msg = bincode2::serialize(&msg) + .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; + let message = DKGMessage { + sender_id: self.authority_id.clone(), + recipient_id: None, // We always gossip in FROST + payload: NetworkMsgPayload::Keygen(DKGKeygenMessage { + sender_id: 0, // We do not care to put the sender ID in the message for FROST, since it is already inside the FrostMessage + keygen_msg,// The Frost Message + keygen_protocol_hash: self.keygen_protocol_hash, + }), + session_id: self.remote.session_id, + associated_block_id: self.remote.associated_block_id, + ssid: self.remote.ssid, + }; + + self.engine.sign_and_send_msg(message) + } +} diff --git a/dkg-gadget/src/async_protocols/frost/mod.rs b/dkg-gadget/src/async_protocols/frost/mod.rs new file mode 100644 index 000000000..e51884adf --- /dev/null +++ b/dkg-gadget/src/async_protocols/frost/mod.rs @@ -0,0 +1,2 @@ +pub mod keygen; +pub mod sign; diff --git a/dkg-gadget/src/async_protocols/frost/sign/mod.rs b/dkg-gadget/src/async_protocols/frost/sign/mod.rs new file mode 100644 index 000000000..e69de29bb diff --git a/dkg-gadget/src/async_protocols/mod.rs b/dkg-gadget/src/async_protocols/mod.rs index cc3c1dab1..804ee5204 100644 --- a/dkg-gadget/src/async_protocols/mod.rs +++ b/dkg-gadget/src/async_protocols/mod.rs @@ -14,11 +14,11 @@ pub mod blockchain_interface; pub mod incoming; -pub mod keygen; pub mod remote; -pub mod sign; pub mod state_machine; pub mod state_machine_wrapper; +pub mod ecdsa; +pub mod frost; use sp_runtime::traits::Get; #[cfg(test)] pub mod test_utils; diff --git a/dkg-gadget/src/dkg_modules/mod.rs b/dkg-gadget/src/dkg_modules/mod.rs index c30fddf8e..1975b37be 100644 --- a/dkg-gadget/src/dkg_modules/mod.rs +++ b/dkg-gadget/src/dkg_modules/mod.rs @@ -21,7 +21,6 @@ use wt_frost::WTFrostDKG; pub mod mp_ecdsa; pub mod wt_frost; -pub mod wt_frost_wsts; /// Setup parameters for the Keygen protocol pub enum KeygenProtocolSetupParameters { diff --git a/dkg-gadget/src/dkg_modules/wt_frost.rs b/dkg-gadget/src/dkg_modules/wt_frost.rs index d51435a62..b34d90ed8 100644 --- a/dkg-gadget/src/dkg_modules/wt_frost.rs +++ b/dkg-gadget/src/dkg_modules/wt_frost.rs @@ -1,40 +1,56 @@ -use crate::{ - dkg_modules::{ - KeygenProtocolSetupParameters, ProtocolInitReturn, SigningProtocolSetupParameters, DKG, - }, - gossip_engine::GossipEngineIface, - worker::DKGWorker, - Client, -}; -use async_trait::async_trait; use dkg_primitives::types::DKGError; +use itertools::Itertools; +use rand::{CryptoRng, RngCore}; +use std::{collections::HashMap, fmt::Debug}; +use async_trait::async_trait; use sc_client_api::Backend; +use serde::{Deserialize, Serialize}; use sp_runtime::traits::Block; +use wsts::{ + common::{PolyCommitment, PublicNonce, Signature, SignatureShare}, + v2, + v2::SignatureAggregator, + Scalar, +}; +use crate::async_protocols::remote::AsyncProtocolRemote; +use crate::Client; +use crate::dkg_modules::{DKG, KeygenProtocolSetupParameters, ProtocolInitReturn, SigningProtocolSetupParameters}; +use crate::gossip_engine::GossipEngineIface; +use crate::worker::DKGWorker; /// DKG module for Weighted Threshold Frost pub struct WTFrostDKG -where - B: Block, - BE: Backend, - C: Client, - GE: GossipEngineIface, + where + B: Block, + BE: Backend, + C: Client, + GE: GossipEngineIface, { pub(super) dkg_worker: DKGWorker, } #[async_trait] impl DKG for WTFrostDKG -where - B: Block, - BE: Backend, - C: Client, - GE: GossipEngineIface, + where + B: Block, + BE: Backend, + C: Client, + GE: GossipEngineIface, { async fn initialize_keygen_protocol( &self, - _params: KeygenProtocolSetupParameters, + params: KeygenProtocolSetupParameters, ) -> Option> { - todo!() + if let KeygenProtocolSetupParameters::WTFrost {} = params { + let remote = AsyncProtocolRemote::new(); + let task = Box::pin(async move { + + }); + + Some((remote, task)) + } else { + None + } } async fn initialize_signing_protocol( @@ -53,141 +69,352 @@ where } } -#[cfg(test)] -mod tests { - use frost_coordinator::coordinator::{Command, Coordinator}; - use frost_signer::{ - config::{Config, PublicKeys}, - net::{Message, NetListen}, - signer::Signer, - signing_round::wsts::{ecdsa, Scalar}, +pub async fn run_dkg( + signer: &mut v2::Party, + rng: &mut RNG, + net: &mut Net, + n_signers: usize, +) -> Result, DKGError> { + // Broadcast our party_id, shares, and key_ids to each other + let party_id = signer.party_id; + let shares: HashMap = signer.get_shares().into_iter().collect(); + let key_ids = signer.key_ids.clone(); + let poly_commitment = signer.get_poly_commitment(rng); + let message = FrostMessage::DKG { + party_id, + shares: shares.clone(), + key_ids: key_ids.clone(), + poly_commitment: poly_commitment.clone(), }; - use futures::{stream::FuturesUnordered, TryStreamExt}; - use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; - #[derive(Clone)] - struct TestNetworkLayer { - tx: tokio::sync::broadcast::Sender, - rx: Arc>>, + // Send the message + net.send_message(message).await.map_err(|err| DKGError::GenericError { + reason: format!("Error sending FROST message: {err:?}"), + })?; + + let mut received_shares = HashMap::new(); + let mut received_key_ids = HashMap::new(); + let mut received_poly_commitments = HashMap::new(); + // insert our own shared into the received map + received_shares.insert(party_id, shares); + received_key_ids.insert(party_id, key_ids); + received_poly_commitments.insert(party_id, poly_commitment); + + // Wait for n_signers to send their messages to us + while received_shares.len() < n_signers { + match net.next_message().await { + Ok(Some(FrostMessage::DKG { party_id, shares, key_ids, poly_commitment })) => { + received_shares.insert(party_id, shares); + received_key_ids.insert(party_id, key_ids); + received_poly_commitments.insert(party_id, poly_commitment); + }, + + Ok(Some(_)) | Err(_) => {}, + None => + return Err(DKGError::GenericError { + reason: "NetListen connection died".to_string(), + }), + } } - #[async_trait::async_trait] - impl NetListen for TestNetworkLayer { - type Error = frost_signer::net::Error; + // Generate the party_shares: for each key id we own, we take our received key share at that + // index + let party_shares = signer + .key_ids + .iter() + .copied() + .map(|key_id| { + let mut key_shares = HashMap::new(); + + for (id, shares) in &received_shares { + key_shares.insert(*id, shares[&key_id]); + } - async fn poll(&self, _arg: u32) {} + (key_id, key_shares.into_iter().collect()) + }) + .collect(); + let polys = received_poly_commitments + .iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .map(|r| r.1.clone()) + .collect_vec(); + signer + .compute_secret(&party_shares, &polys) + .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; + Ok(polys) +} + +pub async fn run_signing( + signer: &mut v2::Party, + rng: &mut RNG, + msg: &[u8], + net: &mut Net, + n_signers: usize, + num_keys: u32, + threshold: u32, + public_key: Vec, +) -> Result { + // Broadcast the party_id, key_ids, and nonce to each other + let nonce = signer.gen_nonce(rng); + let party_id = signer.party_id; + let key_ids = signer.key_ids.clone(); + let message = FrostMessage::Sign { party_id, key_ids: key_ids.clone(), nonce: nonce.clone() }; + + // Send the message + net.send_message(message).await.map_err(|err| DKGError::GenericError { + reason: format!("Error sending FROST message: {err:?}"), + })?; + + let mut party_key_ids = HashMap::new(); + let mut party_nonces = HashMap::new(); + + party_key_ids.insert(party_id, key_ids); + party_nonces.insert(party_id, nonce); - async fn next_message(&self) -> Option { - dkg_logging::info!(target: "dkg", "Waiting for message"); - let msg = self.rx.lock().await.recv().await.ok(); - dkg_logging::info!(target: "dkg", "Received message"); - msg + while party_nonces.len() < n_signers { + match net.next_message().await { + Ok(Some(FrostMessage::Sign { party_id: party_id_recv, key_ids, nonce })) => { + party_key_ids.insert(party_id_recv, key_ids); + party_nonces.insert(party_id_recv, nonce); + }, + + Ok(Some(_)) | Err(_) => {}, + None => + return Err(DKGError::GenericError { + reason: "NetListen connection died".to_string(), + }), } + } + + // Sort the vecs + let party_ids = (0..n_signers).into_iter().map(|r| r as u32).collect_vec(); + let party_key_ids = party_key_ids + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .flat_map(|r| r.1) + .collect_vec(); + let party_nonces = party_nonces + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .map(|r| r.1) + .collect_vec(); + + // Generate our signature share + let signature_share = signer.sign(msg, &party_ids, &party_key_ids, &party_nonces); + let message = FrostMessage::SignFinal { party_id, signature_share: signature_share.clone() }; + // Broadcast our signature share to each other + net.send_message(message).await.map_err(|err| DKGError::GenericError { + reason: format!("Error sending FROST message: {err:?}"), + })?; + + let mut signature_shares = HashMap::new(); + signature_shares.insert(party_id, signature_share.clone()); - async fn send_message(&self, msg: Message) -> Result<(), Self::Error> { - dkg_logging::info!(target: "dkg", "Sending message"); - self.tx.send(msg).map(|_| ()).map_err(|_| frost_signer::net::Error::Timeout)?; - dkg_logging::info!(target: "dkg", "Sent message"); - Ok(()) + // Receive n_signers number of shares + while signature_shares.len() < n_signers { + match net.next_message().await { + Ok(Some(FrostMessage::SignFinal { party_id, signature_share })) => { + signature_shares.insert(party_id, signature_share); + }, + + Ok(Some(_)) | Err(_) => {}, + None => + return Err(DKGError::GenericError { + reason: "NetListen connection died".to_string(), + }), } } - fn create_signer_key_ids(signer_id: u32, keys_per_signer: u32) -> Vec { - (0..keys_per_signer).map(|i| keys_per_signer * signer_id + i + 1).collect() + // Sort the signature shares + let signature_shares = signature_shares + .into_iter() + .sorted_by(|a, b| a.0.cmp(&b.0)) + .map(|r| r.1) + .collect_vec(); + + // Aggregate and sign to generate the signature + let mut sig_agg = SignatureAggregator::new(num_keys, threshold, public_key) + .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; + + sig_agg + .sign(msg, &party_nonces, &signature_shares, &party_key_ids) + .map_err(|err| DKGError::GenericError { reason: err.to_string() }) +} + +pub fn create_signer_key_ids(signer_id: u32, keys_per_signer: u32) -> Vec { + (0..keys_per_signer).map(|i| keys_per_signer * signer_id + i).collect() +} + +/// Returns a Vec of indices that denotes which indexes within the public key vector +/// are owned by which party. +/// +/// E.g., if n=4 and k=10, +/// +/// let party_key_ids: Vec> = [ +/// [0, 1, 2].to_vec(), +/// [3, 4].to_vec(), +/// [5, 6, 7].to_vec(), +/// [8, 9].to_vec(), +/// ] +/// +/// In the above case, we go up from 0..=9 possible key ids since k=10, and +/// we have 4 grouping since n=4. We need to generalize this below +#[allow(dead_code)] +pub fn generate_party_key_ids(n: u32, k: u32) -> Vec> { + let mut result = Vec::with_capacity(n as usize); + let ids_per_party = k / n; + let mut start = 0; + + for _ in 0..n { + let end = start + ids_per_party; + let ids = (start..end).collect(); + result.push(ids); + start = end; } - fn create_public_keys(signer_private_keys: &Vec, keys_per_signer: u32) -> PublicKeys { - let signer_id_keys: HashMap = signer_private_keys - .iter() - .enumerate() - .map(|(i, key)| ((i + 1) as u32, ecdsa::PublicKey::new(key).unwrap())) - .collect(); - - let key_ids = signer_id_keys - .iter() - .flat_map(|(signer_id, signer_key)| { - (0..keys_per_signer).map(|i| (keys_per_signer * *signer_id - i, signer_key.clone())) - }) - .collect(); + result +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum FrostMessage { + DKG { + party_id: u32, + shares: HashMap, + key_ids: Vec, + poly_commitment: PolyCommitment, + }, + Sign { + party_id: u32, + key_ids: Vec, + nonce: PublicNonce, + }, + SignFinal { + party_id: u32, + signature_share: SignatureShare, + }, +} + +#[async_trait::async_trait] +pub trait NetInterface { + type Error: Debug; + + async fn next_message(&mut self) -> Result, Self::Error>; + async fn send_message(&mut self, msg: FrostMessage) -> Result<(), Self::Error>; +} + +#[cfg(test)] +mod tests { + use crate::dkg_modules::wt_frost::{FrostMessage, NetInterface}; + use futures::{stream::FuturesUnordered, TryStreamExt}; + + struct TestNetworkLayer { + tx: tokio::sync::broadcast::Sender, + rx: tokio::sync::broadcast::Receiver, + } + + #[async_trait::async_trait] + impl NetInterface for TestNetworkLayer { + type Error = tokio::sync::broadcast::error::SendError; + + async fn next_message(&mut self) -> Result, Self::Error> { + Ok(self.rx.recv().await.ok()) + } - PublicKeys { signers: signer_id_keys.into_iter().collect(), key_ids } + async fn send_message(&mut self, msg: FrostMessage) -> Result<(), Self::Error> { + self.tx.send(msg).map(|_| ()) + } } #[tokio::test] - async fn test_dkg() { + async fn test_n3t2k3() { + test_inner::<3, 2, 3>().await; + } + + async fn test_inner() { dkg_logging::setup_log(); - let t = 15; - let n = 5; - let keys_per_signer = 5; + assert_eq!(K % N, 0); // Enforce that each party owns the same number of keys + assert_ne!(K, 0); // Enforce that K is not zero + assert!(N > T); - let (tx, _) = tokio::sync::broadcast::channel(1000); + // Each node creates their own party + let mut parties = Vec::new(); + let indices = super::generate_party_key_ids(N, K); + let rng = &mut rand::thread_rng(); + // In reality, the idx below would be our index in the best authorities, starting from zero + for (idx, key_indexes_owned_by_this_party) in indices.into_iter().enumerate() { + // See https://github.com/Trust-Machines/wsts/blob/037e2eb4105cf9f9b1c034ee5c1540a40123b530/src/v2.rs#L515 + // for generating the party key IDS + //let key_indexes_owned_by_this_party = super::create_signer_key_ids(idx, K); + dkg_logging::info!(target: "dkg", "keys owned by party {idx}: {key_indexes_owned_by_this_party:?}"); + parties.push(wsts::v2::Party::new( + idx as _, + &key_indexes_owned_by_this_party, + N, + K, + T, + rng, + )); + } - // Generate n+1 pub/priv keys, as well as their network layer - let mut networks = (1..=(n + 1)) + // setup the network + let (tx, _) = tokio::sync::broadcast::channel(1000); + let mut networks = (0..N) .into_iter() - .map(|_idx| { - let secret_key = Scalar::random(&mut rand::thread_rng()); - let public_key = ecdsa::PublicKey::new(&secret_key).unwrap(); - ( - public_key, - secret_key, - TestNetworkLayer { - tx: tx.clone(), - rx: tokio::sync::Mutex::new(tx.subscribe()).into(), - }, - ) + .map(|_idx| TestNetworkLayer { + tx: tx.clone(), + rx: tx.subscribe(), }) .collect::>(); - let public_keys = create_public_keys( - &networks.iter().map(|(_, private_key, _)| private_key.clone()).collect(), - keys_per_signer, - ); - let signer_key_ids: HashMap> = (0..n) - .into_iter() - .map(|i| (i + 1, create_signer_key_ids(i, keys_per_signer))) - .collect(); - - // Generate the coordinator - let (coordinator_pub_key, coordinator_priv_key, coordinator_network) = - networks.pop().unwrap(); - let coordinator_config = Config::new( - t, - coordinator_pub_key, - public_keys.clone(), - signer_key_ids.clone().into_iter().collect(), - coordinator_priv_key, - Default::default(), - ); - let mut coordinator = - Coordinator::new(0, &coordinator_config, coordinator_network).unwrap(); - - dkg_logging::info!(target: "dkg", "Signer key IDs: {:?}", signer_key_ids); - let nodes = FuturesUnordered::new(); - - nodes.push(Box::pin(async move { - coordinator.run(&Command::Dkg).await.map_err(|err| err.to_string())?; - Ok::<_, String>(()) - }) as Pin> + Send>>); - - // Create a config and coordinator for each node - for (i, (_public_key, secret_key, network)) in networks.into_iter().enumerate() { - let signer_config = Config::new( - t, - coordinator_pub_key.clone(), - public_keys.clone(), - signer_key_ids.clone().into_iter().collect(), - secret_key, - Default::default(), - ); - let mut signer = Signer::new(signer_config, (i + 1) as u32); - - nodes.push(Box::pin(async move { - signer.start_p2p_async(network).await.map_err(|err| err.to_string())?; - Ok::<_, String>(()) + // Test the DKG + let dkgs = FuturesUnordered::new(); + for (party, network) in parties.iter_mut().zip(networks.iter_mut()) { + dkgs.push(Box::pin(async move { + let mut rng = rand::thread_rng(); + crate::dkg_modules::wt_frost::run_dkg(party, &mut rng, network, N as _).await })); } - nodes.try_collect::>().await.unwrap(); + let mut public_keys = dkgs.try_collect::>().await.unwrap(); + for public_key in &public_keys { + assert_eq!(public_key.len(), N as usize); + for public_key0 in &public_keys { + // Assert all equal + assert!(public_key + .iter() + .zip(public_key0) + .all(|r| r.0.id.kG == r.1.id.kG && + r.0.id.id == r.1.id.id && r.0.id.kca == r.1.id.kca && + r.0.A == r.1.A)); + } + } + + let public_key = public_keys.pop().unwrap(); + + // Test the signing over an arbitrary message + let msg = b"Hello, world!"; + + // Start by choosing signers. Since our indexes, in reality, will be based on the set of + // best authorities, we will choose the best of the best of authorities, so from 0..T + let signers = FuturesUnordered::new(); + + for (party, network) in parties.iter_mut().zip(networks.iter_mut()).take(T as _) { + let public_key = public_key.clone(); + signers.push(Box::pin(async move { + let mut rng = rand::thread_rng(); + crate::dkg_modules::wt_frost::run_signing( + party, &mut rng, &*msg, network, T as usize, K, T, public_key, + ) + .await + })); + } + + let signatures = signers.try_collect::>().await.unwrap(); + for signature0 in &signatures { + for signature1 in &signatures { + assert_eq!(signature0.R, signature1.R); + assert_eq!(signature0.z, signature1.z); + } + } } } diff --git a/dkg-gadget/src/dkg_modules/wt_frost_wsts.rs b/dkg-gadget/src/dkg_modules/wt_frost_wsts.rs deleted file mode 100644 index 36a4495df..000000000 --- a/dkg-gadget/src/dkg_modules/wt_frost_wsts.rs +++ /dev/null @@ -1,363 +0,0 @@ -use dkg_primitives::types::DKGError; -use itertools::Itertools; -use rand::{CryptoRng, RngCore}; -use std::{collections::HashMap, fmt::Debug}; -use wsts::{ - common::{PolyCommitment, PublicNonce, Signature, SignatureShare}, - v2, - v2::SignatureAggregator, - Scalar, -}; - -pub async fn run_dkg( - signer: &mut v2::Party, - rng: &mut RNG, - net: &Net, - n_signers: usize, -) -> Result, DKGError> { - // Broadcast our party_id, shares, and key_ids to each other - let party_id = signer.party_id; - let shares: HashMap = signer.get_shares().into_iter().collect(); - let key_ids = signer.key_ids.clone(); - let poly_commitment = signer.get_poly_commitment(rng); - let message = FrostMessage::DKG { - party_id, - shares: shares.clone(), - key_ids: key_ids.clone(), - poly_commitment: poly_commitment.clone(), - }; - - // Send the message - net.send_message(message).await.map_err(|err| DKGError::GenericError { - reason: format!("Error sending FROST message: {err:?}"), - })?; - - let mut received_shares = HashMap::new(); - let mut received_key_ids = HashMap::new(); - let mut received_poly_commitments = HashMap::new(); - // insert our own shared into the received map - received_shares.insert(party_id, shares); - received_key_ids.insert(party_id, key_ids); - received_poly_commitments.insert(party_id, poly_commitment); - - // Wait for n_signers to send their messages to us - while received_shares.len() < n_signers { - match net.next_message().await { - Some(FrostMessage::DKG { party_id, shares, key_ids, poly_commitment }) => { - received_shares.insert(party_id, shares); - received_key_ids.insert(party_id, key_ids); - received_poly_commitments.insert(party_id, poly_commitment); - }, - - Some(_) => {}, - None => - return Err(DKGError::GenericError { - reason: "NetListen connection died".to_string(), - }), - } - } - - // Generate the party_shares: for each key id we own, we take our received key share at that - // index - let party_shares = signer - .key_ids - .iter() - .copied() - .map(|key_id| { - let mut key_shares = HashMap::new(); - - for (id, shares) in &received_shares { - key_shares.insert(*id, shares[&key_id]); - } - - (key_id, key_shares.into_iter().collect()) - }) - .collect(); - let polys = received_poly_commitments - .iter() - .sorted_by(|a, b| a.0.cmp(&b.0)) - .map(|r| r.1.clone()) - .collect_vec(); - signer - .compute_secret(&party_shares, &polys) - .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; - Ok(polys) -} - -pub async fn run_signing( - signer: &mut v2::Party, - rng: &mut RNG, - msg: &[u8], - net: &Net, - n_signers: usize, - num_keys: u32, - threshold: u32, - public_key: Vec, -) -> Result { - // Broadcast the party_id, key_ids, and nonce to each other - let nonce = signer.gen_nonce(rng); - let party_id = signer.party_id; - let key_ids = signer.key_ids.clone(); - let message = FrostMessage::Sign { party_id, key_ids: key_ids.clone(), nonce: nonce.clone() }; - - // Send the message - net.send_message(message).await.map_err(|err| DKGError::GenericError { - reason: format!("Error sending FROST message: {err:?}"), - })?; - - let mut party_key_ids = HashMap::new(); - let mut party_nonces = HashMap::new(); - - party_key_ids.insert(party_id, key_ids); - party_nonces.insert(party_id, nonce); - - while party_nonces.len() < n_signers { - match net.next_message().await { - Some(FrostMessage::Sign { party_id: party_id_recv, key_ids, nonce }) => { - party_key_ids.insert(party_id_recv, key_ids); - party_nonces.insert(party_id_recv, nonce); - }, - - Some(_) => {}, - None => - return Err(DKGError::GenericError { - reason: "NetListen connection died".to_string(), - }), - } - } - - // Sort the vecs all in order - //let party_ids = party_ids.into_iter().sorted().collect_vec(); - let party_ids = (0..n_signers).into_iter().map(|r| r as u32).collect_vec(); - let party_key_ids = party_key_ids - .into_iter() - .sorted_by(|a, b| a.0.cmp(&b.0)) - .flat_map(|r| r.1) - .collect_vec(); - let party_nonces = party_nonces - .into_iter() - .sorted_by(|a, b| a.0.cmp(&b.0)) - .map(|r| r.1) - .collect_vec(); - - // Generate our signature share - let signature_share = signer.sign(msg, &party_ids, &party_key_ids, &party_nonces); - let message = FrostMessage::SignFinal { party_id, signature_share: signature_share.clone() }; - // Broadcast our signature share to each other - net.send_message(message).await.map_err(|err| DKGError::GenericError { - reason: format!("Error sending FROST message: {err:?}"), - })?; - - let mut signature_shares = HashMap::new(); - signature_shares.insert(party_id, signature_share.clone()); - - // Receive n_signers number of shares - while signature_shares.len() < n_signers { - match net.next_message().await { - Some(FrostMessage::SignFinal { party_id, signature_share }) => { - signature_shares.insert(party_id, signature_share); - }, - - Some(_) => {}, - None => - return Err(DKGError::GenericError { - reason: "NetListen connection died".to_string(), - }), - } - } - - // Sort the signature shares - let signature_shares = signature_shares - .into_iter() - .sorted_by(|a, b| a.0.cmp(&b.0)) - .map(|r| r.1) - .collect_vec(); - - // Aggregate and sign to generate the signature - let mut sig_agg = SignatureAggregator::new(num_keys, threshold, public_key) - .map_err(|err| DKGError::GenericError { reason: err.to_string() })?; - - sig_agg - .sign(msg, &party_nonces, &signature_shares, &party_key_ids) - .map_err(|err| DKGError::GenericError { reason: err.to_string() }) -} - -pub fn create_signer_key_ids(signer_id: u32, keys_per_signer: u32) -> Vec { - (0..keys_per_signer).map(|i| keys_per_signer * signer_id + i).collect() -} - -/// Returns a Vec of indices that denotes which indexes within the public key vector -/// are owned by which party. -/// -/// E.g., if n=4 and k=10, -/// -/// let party_key_ids: Vec> = [ -/// [0, 1, 2].to_vec(), -/// [3, 4].to_vec(), -/// [5, 6, 7].to_vec(), -/// [8, 9].to_vec(), -/// ] -/// -/// In the above case, we go up from 0..=9 possible key ids since k=10, and -/// we have 4 grouping since n=4. We need to generalize this below -#[allow(dead_code)] -pub fn generate_party_key_ids(n: u32, k: u32) -> Vec> { - let mut result = Vec::with_capacity(n as usize); - let ids_per_party = k / n; - let mut start = 0; - - for _ in 0..n { - let end = start + ids_per_party; - let ids = (start..end).collect(); - result.push(ids); - start = end; - } - - result -} - -#[derive(Clone, Debug)] -pub enum FrostMessage { - DKG { - party_id: u32, - shares: HashMap, - key_ids: Vec, - poly_commitment: PolyCommitment, - }, - Sign { - party_id: u32, - key_ids: Vec, - nonce: PublicNonce, - }, - SignFinal { - party_id: u32, - signature_share: SignatureShare, - }, -} - -#[async_trait::async_trait] -pub trait NetListen { - type Error: Debug; - - async fn next_message(&self) -> Option; - async fn send_message(&self, msg: FrostMessage) -> Result<(), Self::Error>; -} - -#[cfg(test)] -mod tests { - use crate::dkg_modules::wt_frost_wsts::{FrostMessage, NetListen}; - use futures::{stream::FuturesUnordered, TryStreamExt}; - use std::sync::Arc; - - #[derive(Clone)] - struct TestNetworkLayer { - tx: tokio::sync::broadcast::Sender, - rx: Arc>>, - } - - #[async_trait::async_trait] - impl NetListen for TestNetworkLayer { - type Error = tokio::sync::broadcast::error::SendError; - - async fn next_message(&self) -> Option { - self.rx.lock().await.recv().await.ok() - } - - async fn send_message(&self, msg: FrostMessage) -> Result<(), Self::Error> { - self.tx.send(msg).map(|_| ()) - } - } - - #[tokio::test] - async fn test_n3t2k3_raw() { - test_inner::<3, 2, 3>().await; - } - - async fn test_inner() { - dkg_logging::setup_log(); - assert_eq!(K % N, 0); // Enforce that each party owns the same number of keys - assert_ne!(K, 0); // Enforce that K is not zero - assert!(N > T); - - // Each node creates their own party - let mut parties = Vec::new(); - let indices = super::generate_party_key_ids(N, K); - let rng = &mut rand::thread_rng(); - // In reality, the idx below would be our index in the best authorities, starting from zero - for (idx, key_indexes_owned_by_this_party) in indices.into_iter().enumerate() { - // See https://github.com/Trust-Machines/wsts/blob/037e2eb4105cf9f9b1c034ee5c1540a40123b530/src/v2.rs#L515 - // for generating the party key IDS - //let key_indexes_owned_by_this_party = super::create_signer_key_ids(idx, K); - dkg_logging::info!(target: "dkg", "keys owned by party {idx}: {key_indexes_owned_by_this_party:?}"); - parties.push(wsts::v2::Party::new( - idx as _, - &key_indexes_owned_by_this_party, - N, - K, - T, - rng, - )); - } - - // setup the network - let (tx, _) = tokio::sync::broadcast::channel(1000); - let networks = (0..N) - .into_iter() - .map(|_idx| TestNetworkLayer { - tx: tx.clone(), - rx: tokio::sync::Mutex::new(tx.subscribe()).into(), - }) - .collect::>(); - - // Test the DKG - let dkgs = FuturesUnordered::new(); - for (party, network) in parties.iter_mut().zip(networks.iter()) { - dkgs.push(Box::pin(async move { - let mut rng = rand::thread_rng(); - crate::dkg_modules::wt_frost_wsts::run_dkg(party, &mut rng, network, N as _).await - })); - } - - let mut public_keys = dkgs.try_collect::>().await.unwrap(); - for public_key in &public_keys { - assert_eq!(public_key.len(), N as usize); - for public_key0 in &public_keys { - // Assert all equal - assert!(public_key - .iter() - .zip(public_key0) - .all(|r| r.0.id.kG == r.1.id.kG && - r.0.id.id == r.1.id.id && r.0.id.kca == r.1.id.kca && - r.0.A == r.1.A)); - } - } - - let public_key = public_keys.pop().unwrap(); - - // Test the signing over an arbitrary message - let msg = b"Hello, world!"; - - // Start by choosing signers. Since our indexes, in reality, will be based on the set of - // best authorities, we will choose the best of the best of authorities, so from 0..T - let signers = FuturesUnordered::new(); - - for (party, network) in parties.iter_mut().zip(networks.iter()).take(T as _) { - let public_key = public_key.clone(); - signers.push(Box::pin(async move { - let mut rng = rand::thread_rng(); - crate::dkg_modules::wt_frost_wsts::run_signing( - party, &mut rng, &*msg, network, T as usize, K, T, public_key, - ) - .await - })); - } - - let signatures = signers.try_collect::>().await.unwrap(); - for signature0 in &signatures { - for signature1 in &signatures { - assert_eq!(signature0.R, signature1.R); - assert_eq!(signature0.z, signature1.z); - } - } - } -}