diff --git a/src/fiber/channel.rs b/src/fiber/channel.rs index 112848a20..a1098b35b 100644 --- a/src/fiber/channel.rs +++ b/src/fiber/channel.rs @@ -471,9 +471,10 @@ where Ok(()) } FiberChannelMessage::UpdateTlcInfo(update_tlc_info) => { - let old = state - .remote_tlc_info - .get_or_insert_with(|| ChannelTlcInfo::default()); + let old = state.remote_tlc_info.get_or_insert_with(|| ChannelTlcInfo { + enabled: true, + ..Default::default() + }); if let Some(tlc_expiry_delta) = update_tlc_info.tlc_expiry_delta { old.tlc_expiry_delta = tlc_expiry_delta; @@ -932,47 +933,44 @@ where return Err(ProcessingChannelError::FinalIncorrectPaymentHash); } } else { - match state.public_channel_info.as_ref() { - Some(public_channel_info) if public_channel_info.enabled => { - if state.local_tlc_info.tlc_min_value > received_amount { - return Err(ProcessingChannelError::TlcAmountIsTooLow); - } + if state.is_public() && state.is_tlc_forwarding_enabled() { + if state.local_tlc_info.tlc_min_value > received_amount { + return Err(ProcessingChannelError::TlcAmountIsTooLow); + } - if add_tlc.expiry - < peeled_onion_packet.current.expiry + state.local_tlc_info.tlc_expiry_delta - { - return Err(ProcessingChannelError::IncorrectTlcExpiry); - } + if add_tlc.expiry + < peeled_onion_packet.current.expiry + state.local_tlc_info.tlc_expiry_delta + { + return Err(ProcessingChannelError::IncorrectTlcExpiry); + } - assert!(received_amount >= forward_amount); - let forward_fee = received_amount.saturating_sub(forward_amount); - let fee_rate: u128 = state.local_tlc_info.tlc_fee_proportional_millionths; + assert!(received_amount >= forward_amount); + let forward_fee = received_amount.saturating_sub(forward_amount); + let fee_rate: u128 = state.local_tlc_info.tlc_fee_proportional_millionths; - let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate); - if expected_fee.is_err() || forward_fee < expected_fee.clone().unwrap() { - error!( - "too low forward_fee: {}, expected_fee: {:?}", - forward_fee, expected_fee - ); - return Err(ProcessingChannelError::TlcForwardFeeIsTooLow); - } - // if this is not the last hop, forward TLC to next hop - self.handle_forward_onion_packet( - state, - add_tlc.payment_hash, - peeled_onion_packet.clone(), - add_tlc.tlc_id.into(), - ) - .await?; - } - _ => { - // if we don't have public channel info, we can not forward the TLC - // this may happended some malicious sender build a invalid onion router - return Err(ProcessingChannelError::InvalidState( - "Received AddTlc message, but the channel is not public or disabled" - .to_string(), - )); + let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate); + if expected_fee.is_err() || forward_fee < expected_fee.clone().unwrap() { + error!( + "too low forward_fee: {}, expected_fee: {:?}", + forward_fee, expected_fee + ); + return Err(ProcessingChannelError::TlcForwardFeeIsTooLow); } + // if this is not the last hop, forward TLC to next hop + self.handle_forward_onion_packet( + state, + add_tlc.payment_hash, + peeled_onion_packet.clone(), + add_tlc.tlc_id.into(), + ) + .await?; + } else { + // if we don't have public channel info, we can not forward the TLC + // this may happended some malicious sender build a invalid onion router + return Err(ProcessingChannelError::InvalidState( + "Received AddTlc message, but the channel is not public or disabled" + .to_string(), + )); } } Ok(()) @@ -1365,9 +1363,7 @@ where } if updated { - state - .generate_and_broadcast_channel_update(&self.network) - .await; + state.on_channel_tlc_info_updated(&self.network).await; } Ok(()) @@ -2877,6 +2873,12 @@ pub struct ShutdownInfo { #[serde_as] #[derive(Default, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct ChannelTlcInfo { + // The timestamp when the following information is updated. + pub timestamp: u64, + + // Whether this channel is enabled for TLC forwarding or not. + pub enabled: bool, + // The fee rate for tlc transfers. We only have these values set when // this is a public channel. Both sides may set this value differently. // This is a fee that is paid by the sender of the tlc. @@ -2904,6 +2906,7 @@ impl ChannelTlcInfo { tlc_min_value, tlc_expiry_delta, tlc_fee_proportional_millionths, + enabled: true, ..Default::default() } } @@ -2918,7 +2921,6 @@ impl ChannelTlcInfo { #[serde_as] #[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicChannelInfo { - pub enabled: bool, // Channel announcement signatures, may be empty for private channel. pub local_channel_announcement_signature: Option<(EcdsaSignature, PartialSignature)>, pub remote_channel_announcement_signature: Option<(EcdsaSignature, PartialSignature)>, @@ -2932,10 +2934,7 @@ pub struct PublicChannelInfo { impl PublicChannelInfo { pub fn new() -> Self { - Self { - enabled: true, - ..Default::default() - } + Default::default() } } @@ -3323,6 +3322,10 @@ impl ChannelActorState { matches!(self.state, ChannelState::ChannelReady()) } + pub fn is_tlc_forwarding_enabled(&self) -> bool { + self.local_tlc_info.enabled + } + pub async fn try_create_channel_messages( &mut self, network: &ActorRef, @@ -3472,10 +3475,7 @@ impl ChannelActorState { .await } - async fn generate_and_broadcast_channel_update( - &mut self, - network: &ActorRef, - ) { + async fn on_channel_tlc_info_updated(&mut self, network: &ActorRef) { if self.is_public() { let channel_update = self.generate_channel_update(network).await; network @@ -3525,10 +3525,11 @@ impl ChannelActorState { } fn get_channel_update_channel_flags(&self) -> u32 { - self.public_channel_info - .as_ref() - .and_then(|info: &PublicChannelInfo| (!info.enabled).then_some(CHANNEL_DISABLED_FLAG)) - .unwrap_or(0) + if self.is_tlc_forwarding_enabled() { + 0 + } else { + CHANNEL_DISABLED_FLAG + } } pub fn get_unsigned_channel_update_message(&self) -> Option { @@ -3979,58 +3980,36 @@ impl ChannelActorState { .remote_channel_announcement_signature = Some((ecdsa_signature, partial_signatures)); } - fn get_our_tlc_fee_proportional_millionths(&self) -> u128 { - self.local_tlc_info.tlc_fee_proportional_millionths - } - fn update_our_tlc_fee_proportional_millionths(&mut self, fee: u128) -> bool { - if self.get_our_tlc_fee_proportional_millionths() != fee { - self.local_tlc_info.tlc_fee_proportional_millionths = fee; - true - } else { - false + if self.local_tlc_info.tlc_fee_proportional_millionths == fee { + return false; } - } - - fn get_our_tlc_min_value(&self) -> u128 { - self.local_tlc_info.tlc_min_value + self.local_tlc_info.tlc_fee_proportional_millionths = fee; + true } fn update_our_tlc_min_value(&mut self, value: u128) -> bool { - if self.get_our_tlc_min_value() != value { - self.local_tlc_info.tlc_min_value = value; - true - } else { - false + if self.local_tlc_info.tlc_min_value == value { + return false; } - } - - fn get_our_enabled(&self) -> Option { - self.public_channel_info.as_ref().map(|state| state.enabled) + self.local_tlc_info.tlc_min_value = value; + true } fn update_our_enabled(&mut self, enabled: bool) -> bool { - let old_value = self.get_our_enabled(); - match old_value { - Some(old_value) if old_value == enabled => false, - _ => { - self.public_channel_state_mut().enabled = enabled; - true - } + if self.local_tlc_info.enabled == enabled { + return false; } - } - - fn get_our_tlc_expiry_delta(&self) -> u64 { - self.local_tlc_info.tlc_expiry_delta + self.local_tlc_info.enabled = enabled; + true } fn update_our_tlc_expiry_delta(&mut self, value: u64) -> bool { - if self.get_our_tlc_expiry_delta() != value { - self.local_tlc_info.tlc_expiry_delta = value; - true - } else { - false + if self.local_tlc_info.tlc_expiry_delta == value { + return false; } + self.local_tlc_info.tlc_expiry_delta = value; + true } fn get_total_reserved_ckb_amount(&self) -> u64 { diff --git a/src/fiber/graph.rs b/src/fiber/graph.rs index 4bb435831..2b6b642a8 100644 --- a/src/fiber/graph.rs +++ b/src/fiber/graph.rs @@ -228,8 +228,8 @@ pub struct ChannelUpdateInfo { impl From<&ChannelTlcInfo> for ChannelUpdateInfo { fn from(info: &ChannelTlcInfo) -> Self { Self { - timestamp: 0, - enabled: true, + timestamp: info.timestamp, + enabled: info.enabled, tlc_expiry_delta: info.tlc_expiry_delta, tlc_minimum_value: info.tlc_min_value, fee_rate: info.tlc_fee_proportional_millionths as u64, @@ -611,6 +611,8 @@ where &self, node_id: Pubkey, ) -> impl Iterator { + debug!("get_node_inbounds for node {:?}", node_id); + debug!("all channels {:?}", self.channels); let mut channels: Vec<(Pubkey, Pubkey, &ChannelInfo, &ChannelUpdateInfo)> = self .channels .values() diff --git a/src/fiber/tests/test_utils.rs b/src/fiber/tests/test_utils.rs index bfff4f2ff..2c965e51b 100644 --- a/src/fiber/tests/test_utils.rs +++ b/src/fiber/tests/test_utils.rs @@ -604,9 +604,7 @@ impl NetworkNode { pub async fn disable_channel(&mut self, channel_id: Hash256) { let mut channel_actor_state = self.get_channel_actor_state(channel_id); - let mut public_info = channel_actor_state.public_channel_info.unwrap(); - public_info.enabled = false; - channel_actor_state.public_channel_info = Some(public_info); + channel_actor_state.local_tlc_info.enabled = false; self.update_channel_actor_state(channel_actor_state).await; } diff --git a/src/store/tests/store.rs b/src/store/tests/store.rs index 427ab3a7e..fb92712d5 100644 --- a/src/store/tests/store.rs +++ b/src/store/tests/store.rs @@ -312,7 +312,6 @@ fn test_channel_actor_state_store() { let state = ChannelActorState { state: ChannelState::NegotiatingFunding(NegotiatingFundingFlags::THEIR_INIT_SENT), public_channel_info: Some(PublicChannelInfo { - enabled: false, local_channel_announcement_signature: Some(( mock_ecdsa_signature(), MaybeScalar::two(), @@ -326,6 +325,8 @@ fn test_channel_actor_state_store() { channel_update: None, }), local_tlc_info: ChannelTlcInfo { + enabled: false, + timestamp: 0, tlc_fee_proportional_millionths: 123, tlc_expiry_delta: 3, tlc_min_value: 10, @@ -385,15 +386,7 @@ fn test_channel_actor_state_store() { let get_state = store.get_channel_actor_state(&state.id); assert!(get_state.is_some()); - assert_eq!( - get_state - .unwrap() - .public_channel_info - .as_ref() - .unwrap() - .enabled, - false - ); + assert_eq!(get_state.unwrap().is_tlc_forwarding_enabled(), false); let remote_peer_id = state.get_remote_peer_id(); assert_eq!(