From c950ba192fc6026697c86c9c8db3b6a91e19e57a Mon Sep 17 00:00:00 2001 From: quake Date: Sat, 21 Dec 2024 19:39:59 +0900 Subject: [PATCH] chore: simplify fn prune_messages_to_be_saved --- src/fiber/gossip.rs | 92 ++++++++++----------------------------------- src/fiber/types.rs | 55 ++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 74 deletions(-) diff --git a/src/fiber/gossip.rs b/src/fiber/gossip.rs index 37e0e69e7..0f9154e5b 100644 --- a/src/fiber/gossip.rs +++ b/src/fiber/gossip.rs @@ -1,6 +1,7 @@ use std::{ - collections::{HashMap, HashSet, VecDeque}, + collections::{HashMap, HashSet}, marker::PhantomData, + mem::take, sync::Arc, time::Duration, }; @@ -1017,83 +1018,30 @@ impl ExtendedGossipMessageStoreState { // We will also change the relevant state (e.g. update the latest cursor). // The returned list may be sent to the subscribers. async fn prune_messages_to_be_saved(&mut self) -> Vec { - let complete_messages = self - .messages_to_be_saved - .iter() - .filter(|m| self.has_dependencies_available(m)) - .cloned() - .collect::>(); - self.messages_to_be_saved - .retain(|v| !complete_messages.contains(v)); - - let mut sorted_messages = Vec::with_capacity(complete_messages.len()); - - // Save all the messages to a map so that we can easily order messages by their dependencies. - let mut messages_map: HashMap< - (BroadcastMessageID, bool), - VecDeque, - > = HashMap::new(); - - for new_message in complete_messages { - let key = ( - new_message.message_id(), - match &new_message { - // Message id alone is not enough to differentiate channel updates. - // We need a flag to indicate if the message is an update of node 1. - BroadcastMessageWithTimestamp::ChannelUpdate(channel_update) => { - channel_update.is_update_of_node_1() - } - _ => true, - }, - ); - let messages = messages_map.entry(key).or_default(); - let index = messages.partition_point(|m| m.cursor() < new_message.cursor()); - match messages.get(index + 1) { - Some(message) if message == &new_message => { - // The same message is already saved. - continue; - } - _ => { - messages.insert(index, new_message); - } - } - } + let messages_to_be_saved = take(&mut self.messages_to_be_saved); + let (complete_messages, uncomplete_messages) = messages_to_be_saved + .into_iter() + .partition(|m| self.has_dependencies_available(m)); + self.messages_to_be_saved = uncomplete_messages; - loop { - let key = match messages_map.keys().next() { - None => break, - Some(key) => key.clone(), - }; - let messages = messages_map.remove(&key).expect("key exists"); - if let BroadcastMessageWithTimestamp::ChannelUpdate(channel_update) = &messages[0] { - let outpoint = channel_update.channel_outpoint.clone(); - if let Some(message) = - messages_map.remove(&(BroadcastMessageID::ChannelAnnouncement(outpoint), true)) - { - for message in message { - sorted_messages.push(message); - } - } - } - for message in messages { - sorted_messages.push(message); - } - } + let mut sorted_messages = complete_messages.into_iter().collect::>(); + sorted_messages.sort_unstable(); let mut verified_sorted_messages = Vec::with_capacity(sorted_messages.len()); - for message in sorted_messages { - if let Err(error) = - verify_and_save_broadcast_message(&message, &self.store, &self.chain_actor).await + match verify_and_save_broadcast_message(&message, &self.store, &self.chain_actor).await { - warn!( - "Failed to verify and save message {:?}: {:?}", - message, error - ); - continue; + Ok(_) => { + self.update_last_cursor(message.cursor()); + verified_sorted_messages.push(message); + } + Err(error) => { + warn!( + "Failed to verify and save message {:?}: {:?}", + message, error + ); + } } - self.update_last_cursor(message.cursor()); - verified_sorted_messages.push(message); } verified_sorted_messages diff --git a/src/fiber/types.rs b/src/fiber/types.rs index 03ee618e1..d35122eef 100644 --- a/src/fiber/types.rs +++ b/src/fiber/types.rs @@ -30,6 +30,7 @@ use secp256k1::{ use secp256k1::{Verification, XOnlyPublicKey}; use serde::{Deserialize, Serialize}; use serde_with::serde_as; +use std::cmp::Ordering; use std::fmt::Display; use std::marker::PhantomData; use std::str::FromStr; @@ -2407,6 +2408,20 @@ impl BroadcastMessageWithTimestamp { } } +impl Ord for BroadcastMessageWithTimestamp { + fn cmp(&self, other: &Self) -> Ordering { + self.message_id() + .cmp(&other.message_id()) + .then(self.timestamp().cmp(&other.timestamp())) + } +} + +impl PartialOrd for BroadcastMessageWithTimestamp { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + impl From for BroadcastMessage { fn from(broadcast_message_with_timestamp: BroadcastMessageWithTimestamp) -> Self { match broadcast_message_with_timestamp { @@ -2592,6 +2607,42 @@ pub enum BroadcastMessageID { NodeAnnouncement(Pubkey), } +// We need to implement Ord for BroadcastMessageID to make sure that a ChannelUpdate message is always ordered after ChannelAnnouncement, +// so that we can use it as the sorting key in fn prune_messages_to_be_saved to simplify the logic. +impl Ord for BroadcastMessageID { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + ( + BroadcastMessageID::ChannelAnnouncement(outpoint1), + BroadcastMessageID::ChannelAnnouncement(outpoint2), + ) => outpoint1.cmp(outpoint2), + ( + BroadcastMessageID::ChannelUpdate(outpoint1), + BroadcastMessageID::ChannelUpdate(outpoint2), + ) => outpoint1.cmp(outpoint2), + ( + BroadcastMessageID::NodeAnnouncement(pubkey1), + BroadcastMessageID::NodeAnnouncement(pubkey2), + ) => pubkey1.cmp(pubkey2), + (BroadcastMessageID::ChannelUpdate(_), _) => Ordering::Less, + (BroadcastMessageID::NodeAnnouncement(_), _) => Ordering::Greater, + ( + BroadcastMessageID::ChannelAnnouncement(_), + BroadcastMessageID::NodeAnnouncement(_), + ) => Ordering::Less, + (BroadcastMessageID::ChannelAnnouncement(_), BroadcastMessageID::ChannelUpdate(_)) => { + Ordering::Greater + } + } + } +} + +impl PartialOrd for BroadcastMessageID { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + // 1 byte for message type, 36 bytes for message id const MESSAGE_ID_SIZE: usize = 1 + 36; // 8 bytes for timestamp, MESSAGE_ID_SIZE bytes for message id @@ -2717,13 +2768,13 @@ impl Cursor { } impl Ord for Cursor { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { + fn cmp(&self, other: &Self) -> Ordering { self.to_bytes().cmp(&other.to_bytes()) } } impl PartialOrd for Cursor { - fn partial_cmp(&self, other: &Self) -> Option { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } }