Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the fee rate in oubound channel to calculate the TLC fee #489

Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions src/fiber/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ pub struct AddTlcCommand {
/// Save it for outbound (offered) TLC to backward errors.
/// Use all zeros when no shared secrets are available.
pub shared_secret: [u8; 32],
pub previous_tlc: Option<(Hash256, u64)>,
pub previous_tlc: Option<(Hash256, u64, u128)>,
}
contrun marked this conversation as resolved.
Show resolved Hide resolved

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -1013,35 +1013,25 @@ where
}
} else {
if state.is_public() && state.is_tlc_forwarding_enabled() {
if state.local_tlc_info.tlc_minimum_value > received_amount {
return Err(ProcessingChannelError::TlcAmountIsTooLow);
}

if add_tlc.expiry
< peeled_onion_packet.current.expiry + state.local_tlc_info.tlc_expiry_delta
{
return Err(ProcessingChannelError::IncorrectTlcExpiry);
}

assert!(received_amount >= forward_amount);

// Next forwarding channel will get the forward_fee and check if it's enough.
let forward_fee = received_amount.saturating_sub(forward_amount);
let fee_rate: u128 = state.local_tlc_info.tlc_fee_proportional_millionths;

let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate);
if expected_fee.is_err() || forward_fee < expected_fee.clone().unwrap() {
error!(
"too low forward_fee: {}, expected_fee: {:?}",
forward_fee, expected_fee
);
return Err(ProcessingChannelError::TlcForwardFeeIsTooLow);
}
// if this is not the last hop, forward TLC to next hop
self.register_retryable_forward_tlc(
myself,
state,
add_tlc.tlc_id,
add_tlc.payment_hash,
peeled_onion_packet.clone(),
forward_fee,
)
.await;
} else {
Expand Down Expand Up @@ -1241,6 +1231,7 @@ where
) -> Result<u64, ProcessingChannelError> {
state.check_for_tlc_update(Some(command.amount), true, true)?;
state.check_tlc_expiry(command.expiry)?;
state.check_tlc_forward_amount(command.amount, command.previous_tlc.map(|x| x.2))?;
let tlc = state.create_outbounding_tlc(command.clone());
state.check_insert_tlc(&tlc)?;
state.tlc_state.add_offered_tlc(tlc.clone());
Expand Down Expand Up @@ -1470,9 +1461,15 @@ where
tlc_id: TLCId,
payment_hash: Hash256,
peeled_onion_packet: PeeledPaymentOnionPacket,
forward_fee: u128,
) {
let forward_tlc =
RetryableTlcOperation::ForwardTlc(payment_hash, tlc_id, peeled_onion_packet, true);
let forward_tlc = RetryableTlcOperation::ForwardTlc(
payment_hash,
tlc_id,
peeled_onion_packet,
forward_fee,
true,
);
self.register_retryable_tlc_operation(myself, state, forward_tlc)
.await;
}
Expand All @@ -1498,11 +1495,10 @@ where
payment_hash: Hash256,
try_one_time: bool,
) {
if let Some(RetryableTlcOperation::ForwardTlc(_, _, _, ref mut sent)) = state
.tlc_state
.retryable_tlc_operations
.iter_mut()
.find(|op| matches!(op, RetryableTlcOperation::ForwardTlc(ph, _, _, _) if *ph == payment_hash))
if let Some(RetryableTlcOperation::ForwardTlc(.., ref mut sent)) =
state.tlc_state.retryable_tlc_operations.iter_mut().find(
|op| matches!(op, RetryableTlcOperation::ForwardTlc(ph,..) if *ph == payment_hash),
)
{
*sent = try_one_time;
}
Expand Down Expand Up @@ -1554,6 +1550,7 @@ where
payment_hash,
tlc_id,
ref peeled_onion_packet,
forward_fee,
try_one_time,
) => {
// there is a potential deadlock for waiting the result from another channel actor
Expand All @@ -1571,7 +1568,11 @@ where
match self.network.send_message(NetworkActorMessage::Command(
NetworkActorCommand::SendPaymentOnionPacket(SendOnionPacketCommand {
peeled_onion_packet: peeled_onion_packet.clone(),
previous_tlc: Some((state.get_id(), u64::from(*tlc_id))),
previous_tlc: Some((
state.get_id(),
u64::from(*tlc_id),
*forward_fee,
)),
payment_hash: *payment_hash,
}),
)) {
Expand Down Expand Up @@ -1608,7 +1609,7 @@ where
) {
let pending_ops = state.tlc_state.get_pending_operations();
if let Some((tlc_op, peeled_onion)) = pending_ops.iter().find_map(|op| match op {
RetryableTlcOperation::ForwardTlc(payment_hash, _, peel_onion_packet, _)
RetryableTlcOperation::ForwardTlc(payment_hash, _, peel_onion_packet, ..)
if *payment_hash == result.payment_hash =>
{
Some((op, peel_onion_packet))
Expand Down Expand Up @@ -2582,7 +2583,7 @@ impl From<TlcInfo> for TlcNotifyInfo {
pub enum RetryableTlcOperation {
RemoveTlc(TLCId, RemoveTlcReason),
RelayRemoveTlc(Hash256, u64, RemoveTlcReason),
ForwardTlc(Hash256, TLCId, PeeledPaymentOnionPacket, bool),
ForwardTlc(Hash256, TLCId, PeeledPaymentOnionPacket, u128, bool),
}

impl Debug for RetryableTlcOperation {
Expand All @@ -2599,10 +2600,11 @@ impl Debug for RetryableTlcOperation {
.field(tlc_id)
.field(reason)
.finish(),
RetryableTlcOperation::ForwardTlc(payment_hash, tlc_id, _, run_once) => f
RetryableTlcOperation::ForwardTlc(payment_hash, tlc_id, _, forward_fee, run_once) => f
.debug_tuple("ForwardTlc")
.field(payment_hash)
.field(tlc_id)
.field(forward_fee)
.field(run_once)
.finish(),
}
Expand Down Expand Up @@ -4393,9 +4395,9 @@ impl ChannelActorState {
}

pub fn get_local_channel_update_info(&self) -> ChannelUpdateInfo {
let balance = self.get_remote_balance();
let balance = self.get_local_balance();
let mut info = ChannelUpdateInfo::from(&self.local_tlc_info);
info.inbound_liquidity = Some(balance);
info.outbound_liquidity = Some(balance);
info
}

Expand All @@ -4408,10 +4410,10 @@ impl ChannelActorState {
}

pub fn get_remote_channel_update_info(&self) -> Option<ChannelUpdateInfo> {
let balance = self.get_local_balance();
let balance = self.get_remote_balance();
self.remote_tlc_info.as_ref().map(|tlc_info| {
let mut info = ChannelUpdateInfo::from(tlc_info);
info.inbound_liquidity = Some(balance);
info.outbound_liquidity = Some(balance);
info
})
}
Expand Down Expand Up @@ -4931,6 +4933,47 @@ impl ChannelActorState {
Ok(())
}

fn check_tlc_forward_amount(
&self,
forward_amount: u128,
forward_fee: Option<u128>,
) -> ProcessingChannelResult {
assert!(self.local_tlc_info.enabled, "TLC is disabled");
if self.local_tlc_info.tlc_minimum_value != 0
&& forward_amount < self.local_tlc_info.tlc_minimum_value
{
return Err(ProcessingChannelError::TlcAmountIsTooLow);
}
if self.local_tlc_info.tlc_maximum_value != 0
&& forward_amount > self.local_tlc_info.tlc_minimum_value
{
return Err(ProcessingChannelError::TlcAmountExceedLimit);
}
let forward_fee = match forward_fee {
Some(fee) => fee,
None => {
// We are not forwarding the tlc, so no need to check the fee.
return Ok(());
}
};
let fee_rate = self.local_tlc_info.tlc_fee_proportional_millionths;
let expected_fee = calculate_tlc_forward_fee(forward_amount, fee_rate);
match expected_fee {
Ok(expected_fee) if forward_fee >= expected_fee => Ok(()),
Ok(fee) => {
error!(
"too low forward_fee: {}, expected_fee: {}",
forward_fee, fee
);
Err(ProcessingChannelError::TlcForwardFeeIsTooLow)
}
Err(e) => {
error!("calculate_tlc_forward_fee error: {:?}", e);
Err(ProcessingChannelError::TlcForwardFeeIsTooLow)
}
}
}

// Check whether the reason is valid for removing the tlc.
fn check_remove_tlc_with_reason(
&self,
Expand Down Expand Up @@ -5060,7 +5103,7 @@ impl ChannelActorState {
shared_secret: command.shared_secret,
previous_tlc: command
.previous_tlc
.map(|(channel_id, tlc_id)| (channel_id, TLCId::Received(tlc_id))),
.map(|(channel_id, tlc_id, _)| (channel_id, TLCId::Received(tlc_id))),
}
}

Expand Down
57 changes: 31 additions & 26 deletions src/fiber/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,8 @@ pub struct ChannelUpdateInfo {
pub timestamp: u64,
/// Whether the channel can be currently used for payments (in this one direction).
pub enabled: bool,
/// The exact amount of balance that we can receive from the other party via the channel.
/// Note that this is not our balance, but the balance of the other party.
/// This node is forwarding the balance for the other party, so we need to use the receivable balance
/// instead of our balance.
pub inbound_liquidity: Option<u128>,
/// The exact amount of balance that we can send to the other party via the channel.
pub outbound_liquidity: Option<u128>,
/// The difference in htlc expiry values that you must have when routing through this channel (in milliseconds).
pub tlc_expiry_delta: u64,
/// The minimum value, which must be relayed to the next hop via the channel
Expand All @@ -238,7 +235,7 @@ impl From<&ChannelTlcInfo> for ChannelUpdateInfo {
Self {
timestamp: info.timestamp,
enabled: info.enabled,
inbound_liquidity: None,
outbound_liquidity: None,
tlc_expiry_delta: info.tlc_expiry_delta,
tlc_minimum_value: info.tlc_minimum_value,
fee_rate: info.tlc_fee_proportional_millionths as u64,
Expand All @@ -263,7 +260,7 @@ impl From<&ChannelUpdate> for ChannelUpdateInfo {
Self {
timestamp: update.timestamp,
enabled: !update.is_disabled(),
inbound_liquidity: None,
outbound_liquidity: None,
tlc_expiry_delta: update.tlc_expiry_delta,
tlc_minimum_value: update.tlc_minimum_value,
fee_rate: update.tlc_fee_proportional_millionths as u64,
Expand Down Expand Up @@ -297,7 +294,7 @@ pub struct NetworkGraph<S> {
// The pubkey of the node that is running this instance of the network graph.
source: Pubkey,
// All the channels in the network.
channels: HashMap<OutPoint, ChannelInfo>,
pub(crate) channels: HashMap<OutPoint, ChannelInfo>,
// All the nodes in the network.
nodes: HashMap<Pubkey, NodeInfo>,
// The latest cursor we read from the GossipMessageStore. When we need to refresh our view of the
Expand Down Expand Up @@ -730,16 +727,17 @@ where
.channels
.values()
.filter_map(move |channel| {
if let Some(info) = channel.update_of_node2.as_ref() {
if info.enabled && channel.node2() == node_id {
match channel.update_of_node1.as_ref() {
Some(info) if node_id == channel.node2() && info.enabled => {
return Some((channel.node1(), channel.node2(), channel, info));
}
_ => {}
}

if let Some(info) = channel.update_of_node1.as_ref() {
if info.enabled && channel.node1() == node_id {
match channel.update_of_node2.as_ref() {
Some(info) if node_id == channel.node1() && info.enabled => {
return Some((channel.node2(), channel.node1(), channel, info));
}
_ => {}
}
None
})
Expand All @@ -755,8 +753,8 @@ where
|(_, _, a_channel_info, a_channel_update_info),
(_, _, b_channel_info, b_channel_update_info)| {
b_channel_update_info
.inbound_liquidity
.cmp(&a_channel_update_info.inbound_liquidity)
.outbound_liquidity
contrun marked this conversation as resolved.
Show resolved Hide resolved
.cmp(&a_channel_update_info.outbound_liquidity)
.then(
b_channel_info
.capacity()
Expand Down Expand Up @@ -964,10 +962,11 @@ where

let mut target = target;
let mut expiry = final_tlc_expiry_delta;
let mut amount = amount;
let mut last_edge = None;

if route_to_self {
let (t, edge, e) = self.adjust_target_for_route_self(
let (edge, t, e, f) = self.adjust_target_for_route_self(
&hop_hint_map,
amount,
final_tlc_expiry_delta,
Expand All @@ -977,6 +976,7 @@ where
assert_ne!(target, t);
target = t;
expiry = expiry + e;
amount = amount + f;
last_edge = Some(edge);
}
assert_ne!(source, target);
Expand All @@ -998,7 +998,6 @@ where
for (from, to, channel_info, channel_update) in self.get_node_inbounds(cur_hop.node_id)
{
let is_initial = from == source;
let is_final = (to == target) && !route_to_self;

assert_eq!(to, cur_hop.node_id);
if &udt_type_script != channel_info.udt_type_script() {
Expand Down Expand Up @@ -1034,7 +1033,7 @@ where
continue;
}

let fee = if is_final {
let fee = if is_initial {
0
} else {
calculate_tlc_forward_fee(
Expand Down Expand Up @@ -1066,12 +1065,14 @@ where
if amount_to_send > channel_info.capacity() {
continue;
}
if amount_to_send < channel_update.tlc_minimum_value {
// We should use next_hop_received_amount because that is the amount to be
// sent over the channel.
if next_hop_received_amount < channel_update.tlc_minimum_value {
continue;
}

// If we already know the balance of the channel, check if we can send the amount.
if let Some(balance) = channel_update.inbound_liquidity {
if let Some(balance) = channel_update.outbound_liquidity {
if amount_to_send > balance {
continue;
}
Expand Down Expand Up @@ -1129,10 +1130,9 @@ where
next_hop: Some(PathEdge {
target: to,
channel_outpoint: channel_info.out_point().clone(),
// Here we need to use the amount accumulated so far (i.e. with the fees in current hop)
// because the fee here is for the receiving node to forward the amount to the next node.
// So the total amount in AddTlc packet should include the fee.
amount_received: amount_to_send,
// The amount_received is the amount that next hop is going to receive.
// That is exactly next_hop_received_amount.
amount_received: next_hop_received_amount,
// We need to use cur_hop.incoming_tlc_expiry instead of incoming_tlc_expiry here
// because we need the expiry for the AddTlc packet sent from source to target.
// cur_hop.incoming_tlc_expiry is the expiry time for the TLC that is going to be received by the target,
Expand Down Expand Up @@ -1179,7 +1179,7 @@ where
expiry: u64,
source: Pubkey,
target: Pubkey,
) -> Result<(Pubkey, PathEdge, u64), PathFindError> {
) -> Result<(PathEdge, Pubkey, u64, u128), PathFindError> {
let direct_channels: Vec<(Pubkey, Pubkey, &ChannelInfo, &ChannelUpdateInfo)> = self
.get_node_inbounds(source)
.filter(|(_, _, channel_info, _)| {
Expand Down Expand Up @@ -1224,7 +1224,12 @@ where
amount_received: amount,
incoming_tlc_expiry: expiry,
};
Ok((from, last_edge, channel_update.tlc_expiry_delta))
let fee = calculate_tlc_forward_fee(amount, channel_update.fee_rate as u128).map_err(
|err| {
PathFindError::PathFind(format!("calculate_tlc_forward_fee error: {:?}", err))
},
)?;
Ok((last_edge, from, channel_update.tlc_expiry_delta, fee))
} else {
return Err(PathFindError::PathFind(
"no direct channel found for source node".to_string(),
Expand Down
Loading
Loading