Skip to content

Commit

Permalink
Move fields to ChannelTlcInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
contrun committed Jan 7, 2025
1 parent 1724fa9 commit 185f659
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 109 deletions.
167 changes: 73 additions & 94 deletions src/fiber/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,10 @@ where
Ok(())
}
FiberChannelMessage::UpdateTlcInfo(update_tlc_info) => {
let old = state
.remote_tlc_info
.get_or_insert_with(|| ChannelTlcInfo::default());
let old = state.remote_tlc_info.get_or_insert_with(|| ChannelTlcInfo {
enabled: true,
..Default::default()
});

if let Some(tlc_expiry_delta) = update_tlc_info.tlc_expiry_delta {
old.tlc_expiry_delta = tlc_expiry_delta;
Expand Down Expand Up @@ -932,47 +933,44 @@ where
return Err(ProcessingChannelError::FinalIncorrectPaymentHash);
}
} else {
match state.public_channel_info.as_ref() {
Some(public_channel_info) if public_channel_info.enabled => {
if state.local_tlc_info.tlc_min_value > received_amount {
return Err(ProcessingChannelError::TlcAmountIsTooLow);
}
if state.is_public() && state.is_tlc_forwarding_enabled() {
if state.local_tlc_info.tlc_min_value > received_amount {
return Err(ProcessingChannelError::TlcAmountIsTooLow);
}

if add_tlc.expiry
< peeled_onion_packet.current.expiry + state.local_tlc_info.tlc_expiry_delta
{
return Err(ProcessingChannelError::IncorrectTlcExpiry);
}
if add_tlc.expiry
< peeled_onion_packet.current.expiry + state.local_tlc_info.tlc_expiry_delta
{
return Err(ProcessingChannelError::IncorrectTlcExpiry);
}

assert!(received_amount >= forward_amount);
let forward_fee = received_amount.saturating_sub(forward_amount);
let fee_rate: u128 = state.local_tlc_info.tlc_fee_proportional_millionths;
assert!(received_amount >= forward_amount);
let forward_fee = received_amount.saturating_sub(forward_amount);
let fee_rate: u128 = state.local_tlc_info.tlc_fee_proportional_millionths;

let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate);
if expected_fee.is_err() || forward_fee < expected_fee.clone().unwrap() {
error!(
"too low forward_fee: {}, expected_fee: {:?}",
forward_fee, expected_fee
);
return Err(ProcessingChannelError::TlcForwardFeeIsTooLow);
}
// if this is not the last hop, forward TLC to next hop
self.handle_forward_onion_packet(
state,
add_tlc.payment_hash,
peeled_onion_packet.clone(),
add_tlc.tlc_id.into(),
)
.await?;
}
_ => {
// if we don't have public channel info, we can not forward the TLC
// this may happended some malicious sender build a invalid onion router
return Err(ProcessingChannelError::InvalidState(
"Received AddTlc message, but the channel is not public or disabled"
.to_string(),
));
let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate);
if expected_fee.is_err() || forward_fee < expected_fee.clone().unwrap() {
error!(
"too low forward_fee: {}, expected_fee: {:?}",
forward_fee, expected_fee
);
return Err(ProcessingChannelError::TlcForwardFeeIsTooLow);
}
// if this is not the last hop, forward TLC to next hop
self.handle_forward_onion_packet(
state,
add_tlc.payment_hash,
peeled_onion_packet.clone(),
add_tlc.tlc_id.into(),
)
.await?;
} else {
// if we don't have public channel info, we can not forward the TLC
// this may happended some malicious sender build a invalid onion router
return Err(ProcessingChannelError::InvalidState(
"Received AddTlc message, but the channel is not public or disabled"
.to_string(),
));
}
}
Ok(())
Expand Down Expand Up @@ -1365,9 +1363,7 @@ where
}

if updated {
state
.generate_and_broadcast_channel_update(&self.network)
.await;
state.on_channel_tlc_info_updated(&self.network).await;
}

Ok(())
Expand Down Expand Up @@ -2877,6 +2873,12 @@ pub struct ShutdownInfo {
#[serde_as]
#[derive(Default, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ChannelTlcInfo {
// The timestamp when the following information is updated.
pub timestamp: u64,

// Whether this channel is enabled for TLC forwarding or not.
pub enabled: bool,

// The fee rate for tlc transfers. We only have these values set when
// this is a public channel. Both sides may set this value differently.
// This is a fee that is paid by the sender of the tlc.
Expand Down Expand Up @@ -2904,6 +2906,7 @@ impl ChannelTlcInfo {
tlc_min_value,
tlc_expiry_delta,
tlc_fee_proportional_millionths,
enabled: true,
..Default::default()
}
}
Expand All @@ -2918,7 +2921,6 @@ impl ChannelTlcInfo {
#[serde_as]
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct PublicChannelInfo {
pub enabled: bool,
// Channel announcement signatures, may be empty for private channel.
pub local_channel_announcement_signature: Option<(EcdsaSignature, PartialSignature)>,
pub remote_channel_announcement_signature: Option<(EcdsaSignature, PartialSignature)>,
Expand All @@ -2932,10 +2934,7 @@ pub struct PublicChannelInfo {

impl PublicChannelInfo {
pub fn new() -> Self {
Self {
enabled: true,
..Default::default()
}
Default::default()
}
}

Expand Down Expand Up @@ -3323,6 +3322,10 @@ impl ChannelActorState {
matches!(self.state, ChannelState::ChannelReady())
}

pub fn is_tlc_forwarding_enabled(&self) -> bool {
self.local_tlc_info.enabled
}

pub async fn try_create_channel_messages(
&mut self,
network: &ActorRef<NetworkActorMessage>,
Expand Down Expand Up @@ -3472,10 +3475,7 @@ impl ChannelActorState {
.await
}

async fn generate_and_broadcast_channel_update(
&mut self,
network: &ActorRef<NetworkActorMessage>,
) {
async fn on_channel_tlc_info_updated(&mut self, network: &ActorRef<NetworkActorMessage>) {
if self.is_public() {
let channel_update = self.generate_channel_update(network).await;
network
Expand Down Expand Up @@ -3525,10 +3525,11 @@ impl ChannelActorState {
}

fn get_channel_update_channel_flags(&self) -> u32 {
self.public_channel_info
.as_ref()
.and_then(|info: &PublicChannelInfo| (!info.enabled).then_some(CHANNEL_DISABLED_FLAG))
.unwrap_or(0)
if self.is_tlc_forwarding_enabled() {
0
} else {
CHANNEL_DISABLED_FLAG
}
}

pub fn get_unsigned_channel_update_message(&self) -> Option<ChannelUpdate> {
Expand Down Expand Up @@ -3979,58 +3980,36 @@ impl ChannelActorState {
.remote_channel_announcement_signature = Some((ecdsa_signature, partial_signatures));
}

fn get_our_tlc_fee_proportional_millionths(&self) -> u128 {
self.local_tlc_info.tlc_fee_proportional_millionths
}

fn update_our_tlc_fee_proportional_millionths(&mut self, fee: u128) -> bool {
if self.get_our_tlc_fee_proportional_millionths() != fee {
self.local_tlc_info.tlc_fee_proportional_millionths = fee;
true
} else {
false
if self.local_tlc_info.tlc_fee_proportional_millionths == fee {
return false;
}
}

fn get_our_tlc_min_value(&self) -> u128 {
self.local_tlc_info.tlc_min_value
self.local_tlc_info.tlc_fee_proportional_millionths = fee;
true
}

fn update_our_tlc_min_value(&mut self, value: u128) -> bool {
if self.get_our_tlc_min_value() != value {
self.local_tlc_info.tlc_min_value = value;
true
} else {
false
if self.local_tlc_info.tlc_min_value == value {
return false;
}
}

fn get_our_enabled(&self) -> Option<bool> {
self.public_channel_info.as_ref().map(|state| state.enabled)
self.local_tlc_info.tlc_min_value = value;
true
}

fn update_our_enabled(&mut self, enabled: bool) -> bool {
let old_value = self.get_our_enabled();
match old_value {
Some(old_value) if old_value == enabled => false,
_ => {
self.public_channel_state_mut().enabled = enabled;
true
}
if self.local_tlc_info.enabled == enabled {
return false;
}
}

fn get_our_tlc_expiry_delta(&self) -> u64 {
self.local_tlc_info.tlc_expiry_delta
self.local_tlc_info.enabled = enabled;
true
}

fn update_our_tlc_expiry_delta(&mut self, value: u64) -> bool {
if self.get_our_tlc_expiry_delta() != value {
self.local_tlc_info.tlc_expiry_delta = value;
true
} else {
false
if self.local_tlc_info.tlc_expiry_delta == value {
return false;
}
self.local_tlc_info.tlc_expiry_delta = value;
true
}

fn get_total_reserved_ckb_amount(&self) -> u64 {
Expand Down
6 changes: 4 additions & 2 deletions src/fiber/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ pub struct ChannelUpdateInfo {
impl From<&ChannelTlcInfo> for ChannelUpdateInfo {
fn from(info: &ChannelTlcInfo) -> Self {
Self {
timestamp: 0,
enabled: true,
timestamp: info.timestamp,
enabled: info.enabled,
tlc_expiry_delta: info.tlc_expiry_delta,
tlc_minimum_value: info.tlc_min_value,
fee_rate: info.tlc_fee_proportional_millionths as u64,
Expand Down Expand Up @@ -611,6 +611,8 @@ where
&self,
node_id: Pubkey,
) -> impl Iterator<Item = (Pubkey, Pubkey, &ChannelInfo, &ChannelUpdateInfo)> {
debug!("get_node_inbounds for node {:?}", node_id);
debug!("all channels {:?}", self.channels);
let mut channels: Vec<(Pubkey, Pubkey, &ChannelInfo, &ChannelUpdateInfo)> = self
.channels
.values()
Expand Down
4 changes: 1 addition & 3 deletions src/fiber/tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,7 @@ impl NetworkNode {

pub async fn disable_channel(&mut self, channel_id: Hash256) {
let mut channel_actor_state = self.get_channel_actor_state(channel_id);
let mut public_info = channel_actor_state.public_channel_info.unwrap();
public_info.enabled = false;
channel_actor_state.public_channel_info = Some(public_info);
channel_actor_state.local_tlc_info.enabled = false;
self.update_channel_actor_state(channel_actor_state).await;
}

Expand Down
13 changes: 3 additions & 10 deletions src/store/tests/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ fn test_channel_actor_state_store() {
let state = ChannelActorState {
state: ChannelState::NegotiatingFunding(NegotiatingFundingFlags::THEIR_INIT_SENT),
public_channel_info: Some(PublicChannelInfo {
enabled: false,
local_channel_announcement_signature: Some((
mock_ecdsa_signature(),
MaybeScalar::two(),
Expand All @@ -326,6 +325,8 @@ fn test_channel_actor_state_store() {
channel_update: None,
}),
local_tlc_info: ChannelTlcInfo {
enabled: false,
timestamp: 0,
tlc_fee_proportional_millionths: 123,
tlc_expiry_delta: 3,
tlc_min_value: 10,
Expand Down Expand Up @@ -385,15 +386,7 @@ fn test_channel_actor_state_store() {

let get_state = store.get_channel_actor_state(&state.id);
assert!(get_state.is_some());
assert_eq!(
get_state
.unwrap()
.public_channel_info
.as_ref()
.unwrap()
.enabled,
false
);
assert_eq!(get_state.unwrap().is_tlc_forwarding_enabled(), false);

let remote_peer_id = state.get_remote_peer_id();
assert_eq!(
Expand Down

0 comments on commit 185f659

Please sign in to comment.