Skip to content

Commit

Permalink
Merge pull request #414 from quake/quake/refactor-prune_messages_to_b…
Browse files Browse the repository at this point in the history
…e_saved

chore: simplify fn prune_messages_to_be_saved
  • Loading branch information
quake authored Dec 25, 2024
2 parents aebf41f + c950ba1 commit 6042adc
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 74 deletions.
92 changes: 20 additions & 72 deletions src/fiber/gossip.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
collections::{HashMap, HashSet},
marker::PhantomData,
mem::take,
sync::Arc,
time::Duration,
};
Expand Down Expand Up @@ -1017,83 +1018,30 @@ impl<S: GossipMessageStore> ExtendedGossipMessageStoreState<S> {
// We will also change the relevant state (e.g. update the latest cursor).
// The returned list may be sent to the subscribers.
async fn prune_messages_to_be_saved(&mut self) -> Vec<BroadcastMessageWithTimestamp> {
let complete_messages = self
.messages_to_be_saved
.iter()
.filter(|m| self.has_dependencies_available(m))
.cloned()
.collect::<HashSet<_>>();
self.messages_to_be_saved
.retain(|v| !complete_messages.contains(v));

let mut sorted_messages = Vec::with_capacity(complete_messages.len());

// Save all the messages to a map so that we can easily order messages by their dependencies.
let mut messages_map: HashMap<
(BroadcastMessageID, bool),
VecDeque<BroadcastMessageWithTimestamp>,
> = HashMap::new();

for new_message in complete_messages {
let key = (
new_message.message_id(),
match &new_message {
// Message id alone is not enough to differentiate channel updates.
// We need a flag to indicate if the message is an update of node 1.
BroadcastMessageWithTimestamp::ChannelUpdate(channel_update) => {
channel_update.is_update_of_node_1()
}
_ => true,
},
);
let messages = messages_map.entry(key).or_default();
let index = messages.partition_point(|m| m.cursor() < new_message.cursor());
match messages.get(index + 1) {
Some(message) if message == &new_message => {
// The same message is already saved.
continue;
}
_ => {
messages.insert(index, new_message);
}
}
}
let messages_to_be_saved = take(&mut self.messages_to_be_saved);
let (complete_messages, uncomplete_messages) = messages_to_be_saved
.into_iter()
.partition(|m| self.has_dependencies_available(m));
self.messages_to_be_saved = uncomplete_messages;

loop {
let key = match messages_map.keys().next() {
None => break,
Some(key) => key.clone(),
};
let messages = messages_map.remove(&key).expect("key exists");
if let BroadcastMessageWithTimestamp::ChannelUpdate(channel_update) = &messages[0] {
let outpoint = channel_update.channel_outpoint.clone();
if let Some(message) =
messages_map.remove(&(BroadcastMessageID::ChannelAnnouncement(outpoint), true))
{
for message in message {
sorted_messages.push(message);
}
}
}
for message in messages {
sorted_messages.push(message);
}
}
let mut sorted_messages = complete_messages.into_iter().collect::<Vec<_>>();
sorted_messages.sort_unstable();

let mut verified_sorted_messages = Vec::with_capacity(sorted_messages.len());

for message in sorted_messages {
if let Err(error) =
verify_and_save_broadcast_message(&message, &self.store, &self.chain_actor).await
match verify_and_save_broadcast_message(&message, &self.store, &self.chain_actor).await
{
warn!(
"Failed to verify and save message {:?}: {:?}",
message, error
);
continue;
Ok(_) => {
self.update_last_cursor(message.cursor());
verified_sorted_messages.push(message);
}
Err(error) => {
warn!(
"Failed to verify and save message {:?}: {:?}",
message, error
);
}
}
self.update_last_cursor(message.cursor());
verified_sorted_messages.push(message);
}

verified_sorted_messages
Expand Down
55 changes: 53 additions & 2 deletions src/fiber/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use secp256k1::{
use secp256k1::{Verification, XOnlyPublicKey};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::cmp::Ordering;
use std::fmt::Display;
use std::marker::PhantomData;
use std::str::FromStr;
Expand Down Expand Up @@ -2407,6 +2408,20 @@ impl BroadcastMessageWithTimestamp {
}
}

impl Ord for BroadcastMessageWithTimestamp {
fn cmp(&self, other: &Self) -> Ordering {
self.message_id()
.cmp(&other.message_id())
.then(self.timestamp().cmp(&other.timestamp()))
}
}

impl PartialOrd for BroadcastMessageWithTimestamp {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl From<BroadcastMessageWithTimestamp> for BroadcastMessage {
fn from(broadcast_message_with_timestamp: BroadcastMessageWithTimestamp) -> Self {
match broadcast_message_with_timestamp {
Expand Down Expand Up @@ -2592,6 +2607,42 @@ pub enum BroadcastMessageID {
NodeAnnouncement(Pubkey),
}

// We need to implement Ord for BroadcastMessageID to make sure that a ChannelUpdate message is always ordered after ChannelAnnouncement,
// so that we can use it as the sorting key in fn prune_messages_to_be_saved to simplify the logic.
impl Ord for BroadcastMessageID {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(
BroadcastMessageID::ChannelAnnouncement(outpoint1),
BroadcastMessageID::ChannelAnnouncement(outpoint2),
) => outpoint1.cmp(outpoint2),
(
BroadcastMessageID::ChannelUpdate(outpoint1),
BroadcastMessageID::ChannelUpdate(outpoint2),
) => outpoint1.cmp(outpoint2),
(
BroadcastMessageID::NodeAnnouncement(pubkey1),
BroadcastMessageID::NodeAnnouncement(pubkey2),
) => pubkey1.cmp(pubkey2),
(BroadcastMessageID::ChannelUpdate(_), _) => Ordering::Less,
(BroadcastMessageID::NodeAnnouncement(_), _) => Ordering::Greater,
(
BroadcastMessageID::ChannelAnnouncement(_),
BroadcastMessageID::NodeAnnouncement(_),
) => Ordering::Less,
(BroadcastMessageID::ChannelAnnouncement(_), BroadcastMessageID::ChannelUpdate(_)) => {
Ordering::Greater
}
}
}
}

impl PartialOrd for BroadcastMessageID {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

// 1 byte for message type, 36 bytes for message id
const MESSAGE_ID_SIZE: usize = 1 + 36;
// 8 bytes for timestamp, MESSAGE_ID_SIZE bytes for message id
Expand Down Expand Up @@ -2717,13 +2768,13 @@ impl Cursor {
}

impl Ord for Cursor {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
fn cmp(&self, other: &Self) -> Ordering {
self.to_bytes().cmp(&other.to_bytes())
}
}

impl PartialOrd for Cursor {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
Expand Down

0 comments on commit 6042adc

Please sign in to comment.