diff --git a/simln-lib/src/sim_node.rs b/simln-lib/src/sim_node.rs index dde16835..7e21976a 100755 --- a/simln-lib/src/sim_node.rs +++ b/simln-lib/src/sim_node.rs @@ -7,8 +7,11 @@ use bitcoin::secp256k1::PublicKey; use bitcoin::{Network, ScriptBuf, TxOut}; use lightning::ln::chan_utils::make_funding_redeemscript; use std::collections::{hash_map::Entry, HashMap}; +use std::error::Error; +use std::fmt::Display; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::task::JoinSet; use tokio_util::task::TaskTracker; use lightning::ln::features::{ChannelFeatures, NodeFeatures}; @@ -17,7 +20,9 @@ use lightning::ln::msgs::{ }; use lightning::ln::{PaymentHash, PaymentPreimage}; use lightning::routing::gossip::{NetworkGraph, NodeId}; -use lightning::routing::router::{find_route, Path, PaymentParameters, Route, RouteParameters}; +use lightning::routing::router::{ + find_route, Path, PaymentParameters, Route, RouteHop, RouteParameters, +}; use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringDecayParameters}; use lightning::routing::utxo::{UtxoLookup, UtxoResult}; use lightning::util::logger::{Level, Logger, Record}; @@ -81,6 +86,12 @@ pub enum ForwardingError { /// Sanity check on channel balances failed (node balances / channel capacity). #[error("SanityCheckFailed: node balance: {0} != capacity: {1}")] SanityCheckFailed(u64, u64), + #[error("DuplicateCustomRecord: key {0}")] + DuplicateCustomRecord(u64), + #[error("InterceptorError: {0}")] + InterceptorError(Box), + #[error("NotifierError: key {0}")] + NotifierError(Box), } impl ForwardingError { @@ -95,6 +106,8 @@ impl ForwardingError { | ForwardingError::PaymentHashNotFound(_) | ForwardingError::SanityCheckFailed(_, _) | ForwardingError::FeeOverflow(_, _, _) + | ForwardingError::DuplicateCustomRecord(_) + | ForwardingError::NotifierError(_) ) } } @@ -161,8 +174,11 @@ macro_rules! fail_forwarding_inequality { #[derive(Clone)] struct ChannelState { local_balance_msat: u64, - in_flight: HashMap, + /// Maps payment hash to htlc and index that it was added at. + in_flight: HashMap, policy: ChannelPolicy, + /// Tracks unique identifier for htlcs proposed by this node (sent in the outgoing direction). + index: u64, } impl ChannelState { @@ -174,12 +190,13 @@ impl ChannelState { local_balance_msat, in_flight: HashMap::new(), policy, + index: 0, } } /// Returns the sum of all the *in flight outgoing* HTLCs on the channel. fn in_flight_total(&self) -> u64 { - self.in_flight.values().map(|h| h.amount_msat).sum() + self.in_flight.values().map(|h| h.0.amount_msat).sum() } /// Checks whether the proposed HTLC abides by the channel policy advertised for using this channel as the @@ -233,18 +250,21 @@ impl ChannelState { /// /// Note: MPP payments are not currently supported, so this function will fail if a duplicate payment hash is /// reported. - fn add_outgoing_htlc(&mut self, hash: PaymentHash, htlc: Htlc) -> Result<(), ForwardingError> { + fn add_outgoing_htlc(&mut self, hash: PaymentHash, htlc: Htlc) -> Result { self.check_outgoing_addition(&htlc)?; if self.in_flight.contains_key(&hash) { return Err(ForwardingError::PaymentHashExists(hash)); } + let index = self.index; + self.index += 1; + self.local_balance_msat -= htlc.amount_msat; - self.in_flight.insert(hash, htlc); - Ok(()) + self.in_flight.insert(hash, (htlc, index)); + Ok(index) } /// Removes the HTLC from our set of outgoing in-flight HTLCs, failing if the payment hash is not found. - fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result { + fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result<(Htlc, u64), ForwardingError> { self.in_flight .remove(hash) .ok_or(ForwardingError::PaymentHashNotFound(*hash)) @@ -352,14 +372,17 @@ impl SimulatedChannel { sending_node: &PublicKey, hash: PaymentHash, htlc: Htlc, - ) -> Result<(), ForwardingError> { + ) -> Result { if htlc.amount_msat == 0 { return Err(ForwardingError::ZeroAmountHtlc); } - self.get_node_mut(sending_node)? + let index = self + .get_node_mut(sending_node)? .add_outgoing_htlc(hash, htlc)?; - self.sanity_check() + self.sanity_check()?; + + Ok(index) } /// Performs a sanity check on the total balances in a channel. Note that we do not currently include on-chain @@ -382,12 +405,14 @@ impl SimulatedChannel { sending_node: &PublicKey, hash: &PaymentHash, success: bool, - ) -> Result<(), ForwardingError> { + ) -> Result<(Htlc, u64), ForwardingError> { let htlc = self .get_node_mut(sending_node)? .remove_outgoing_htlc(hash)?; - self.settle_htlc(sending_node, htlc.amount_msat, success)?; - self.sanity_check() + self.settle_htlc(sending_node, htlc.0.amount_msat, success)?; + self.sanity_check()?; + + Ok(htlc) } /// Updates the local balance of each node in the channel once a htlc has been resolved, pushing funds to the @@ -643,6 +668,126 @@ impl LightningNode for SimNode<'_, T> { } } +#[async_trait] +pub trait Interceptor: Send + Sync { + /// Implemented by HTLC interceptors that provide input on the resolution of HTLCs forwarded in the simulation. + async fn intercept_htlc(&self, req: InterceptRequest) + -> Result; + + /// Notifies the interceptor that a previously intercepted htlc has been resolved. Default implementation is a no-op + /// for cases where the interceptor only cares about interception, not resolution of htlcs. + async fn notify_resolution( + &self, + _res: InterceptResolution, + ) -> Result<(), Box> { + Ok(()) + } + + /// Returns an identifying name for the interceptor for logging, does not need to be unique. + fn name(&self) -> String; +} + +/// Request sent to an external interceptor to provide feedback on the resolution of the HTLC. +#[derive(Debug, Clone)] +pub struct InterceptRequest { + /// The node that is forwarding this htlc. + pub forwarding_node: PublicKey, + + /// The payment hash for the htlc (note that this is not unique). + pub payment_hash: PaymentHash, + + /// The short channel id for the incoming channel that this htlc was delivered on. + pub incoming_htlc: HtlcRef, + + /// Custom records provided by the incoming htlc. + pub incoming_custom_records: CustomRecords, + + /// The short channel id for the outgoing channel that this htlc should be forwarded over. + pub outgoing_channel_id: Option, + + /// The amount that was forwarded over the incoming_channel_id. + pub incoming_amount_msat: u64, + + /// The amount that will be forwarded over to outgoing_channel_id. + pub outgoing_amount_msat: u64, + + /// The expiry height on the incoming htlc. + pub incoming_expiry_height: u32, + + /// The expiry height on the outgoing htlc. + pub outgoing_expiry_height: u32, +} + +impl InterceptRequest { + fn new( + hop: RouteHop, + payment_hash: PaymentHash, + incoming_amount_msat: u64, + incoming_htlc: HtlcRef, + incoming_custom_records: CustomRecords, + outgoing_channel_id: Option, + incoming_expiry_height: u32, + ) -> Self { + Self { + forwarding_node: hop.pubkey, + payment_hash, + outgoing_channel_id, + incoming_amount_msat, + incoming_htlc, + incoming_custom_records, + outgoing_amount_msat: incoming_amount_msat - hop.fee_msat, + incoming_expiry_height, + outgoing_expiry_height: incoming_expiry_height - hop.cltv_expiry_delta, + } + } +} + +impl Display for InterceptRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "htlc forwarded by {} over {}:{} -> {} forward amounts {} {}", + self.forwarding_node, + self.incoming_htlc.channel_id, + self.incoming_htlc.index, + { + if let Some(c) = self.outgoing_channel_id { + format!("-> {c}") + } else { + "receive".to_string() + } + }, + self.incoming_amount_msat, + self.outgoing_amount_msat + ) + } +} + +/// Notification sent to an external interceptor notifying that a htlc that was previously intercepted has been +/// resolved. +pub struct InterceptResolution { + /// The node that is forwarding this HTLC. + pub forwarding_node: PublicKey, + + /// Unique identifier for the incoming htlc. + pub incoming_htlc: HtlcRef, + + /// The short channel id for the outgoing channel that this htlc should be forwarded over, None if notifying the + /// receiving node. + pub outgoing_channel_id: Option, + + /// True if the htlc was settled successfully. + pub success: bool, +} + +pub type CustomRecords = HashMap>; + +#[derive(Debug, Clone)] +pub struct HtlcRef { + pub channel_id: ShortChannelID, + pub index: u64, +} + /// Graph is the top level struct that is used to coordinate simulation of lightning nodes. pub struct SimGraph { /// nodes caches the list of nodes in the network with a vector of their channel capacities, only used for quick @@ -656,6 +801,9 @@ pub struct SimGraph { /// in this tracker must be done externally. tasks: TaskTracker, + /// Optional set of interceptors that will be called every time a HTLC is added to a simulated channel. + interceptors: Vec>, + /// trigger shutdown if a critical error occurs. shutdown_trigger: Trigger, } @@ -665,6 +813,7 @@ impl SimGraph { pub fn new( graph_channels: Vec, tasks: TaskTracker, + interceptors: Vec>, shutdown_trigger: Trigger, ) -> Result { let mut nodes: HashMap> = HashMap::new(); @@ -699,6 +848,7 @@ impl SimGraph { nodes, channels: Arc::new(Mutex::new(channels)), tasks, + interceptors, shutdown_trigger, }) } @@ -824,6 +974,7 @@ impl SimNetwork for SimGraph { path.clone(), payment_hash, sender, + self.interceptors.clone(), self.shutdown_trigger.clone(), )); } @@ -861,18 +1012,21 @@ async fn add_htlcs( source: PublicKey, route: Path, payment_hash: PaymentHash, + interceptors: Vec>, ) -> Result<(), (Option, ForwardingError)> { let mut outgoing_node = source; let mut outgoing_amount = route.fee_msat() + route.final_value_msat(); let mut outgoing_cltv = route.hops.iter().map(|hop| hop.cltv_expiry_delta).sum(); + let mut incoming_custom_records = HashMap::new(); + // Tracks the hop index that we need to remove htlcs from on payment completion (both success and failure). // Given a payment from A to C, over the route A -- B -- C, this index has the following meanings: // - None: A could not add the outgoing HTLC to B, no action for payment failure. // - Some(0): A -- B added the HTLC but B could not forward the HTLC to C, so it only needs removing on A -- B. // - Some(1): A -- B and B -- C added the HTLC, so it should be removed from the full route. let mut fail_idx = None; - + let last_hop = route.hops.len() - 1; for (i, hop) in route.hops.iter().enumerate() { // Lock the node that we want to add the HTLC to next. We choose to lock one hop at a time (rather than for // the whole route) so that we can mimic the behavior of payments in the real network where the HTLCs in a @@ -880,8 +1034,9 @@ async fn add_htlcs( let mut node_lock = nodes.lock().await; let scid = ShortChannelID::from(hop.short_channel_id); - if let Some(channel) = node_lock.get_mut(&scid) { - channel + let (incoming_htlc, next_scid) = { + if let Some(channel) = node_lock.get_mut(&scid) { + let htlc_index = channel .add_htlc( &outgoing_node, payment_hash, @@ -894,21 +1049,22 @@ async fn add_htlcs( // have to progress our fail_idx. .map_err(|e| (fail_idx, e))?; - // If the HTLC was successfully added, then we'll need to remove the HTLC from this channel if we fail, - // so we progress our failure index to include this node. - fail_idx = Some(i); - - // Once we've added the HTLC on this hop's channel, we want to check whether it has sufficient fee - // and CLTV delta per the _next_ channel's policy (because fees and CLTV delta in LN are charged on - // the outgoing link). We check the policy belonging to the node that we just forwarded to, which - // represents the fee in that direction. - // - // TODO: add invoice-related checks (including final CTLV) if we support non-keysend payments. - if i != route.hops.len() - 1 { - if let Some(channel) = - node_lock.get(&ShortChannelID::from(route.hops[i + 1].short_channel_id)) - { - channel + // If the HTLC was successfully added, then we'll need to remove the HTLC from this channel if we fail, + // so we progress our failure index to include this node. + fail_idx = Some(i); + + // Once we've added the HTLC on this hop's channel, we want to check whether it has sufficient fee + // and CLTV delta per the _next_ channel's policy (because fees and CLTV delta in LN are charged on + // the outgoing link). We check the policy belonging to the node that we just forwarded to, which + // represents the fee in that direction. + // + // TODO: add invoice-related checks (including final CTLV) if we support non-keysend payments. + let mut next_scid = None; + if i != last_hop { + next_scid = Some(ShortChannelID::from(route.hops[i + 1].short_channel_id)); + + if let Some(channel) = node_lock.get(&next_scid.unwrap()) { + channel .check_htlc_forward( &hop.pubkey, hop.cltv_expiry_delta, @@ -918,16 +1074,99 @@ async fn add_htlcs( // If we haven't met forwarding conditions for the next channel's policy, then we fail at // the current index, because we've already added the HTLC as outgoing. .map_err(|e| (fail_idx, e))?; + } + } + let incoming_htlc = HtlcRef { + channel_id: scid, + index: htlc_index, + }; + (incoming_htlc, next_scid) + } else { + return Err((fail_idx, ForwardingError::ChannelNotFound(scid))); + } + }; + + // Before we continue on to the next hop, we'll call any interceptors registered to get external input on the + // forwarding decision for this HTLC. + // + // We drop our node lock so that we can await our interceptors (which may choose to hold the HTLC for a long + // time) without holding our entire graph hostage. + drop(node_lock); + + // Collect any custom records set by the interceptor for the outgoing link. + let mut outgoing_custom_records: HashMap> = HashMap::new(); + + if !interceptors.is_empty() { + let mut intercepts: JoinSet>, ForwardingError>> = + JoinSet::new(); + + for interceptor in interceptors.iter() { + let request = InterceptRequest::new( + hop.clone(), + payment_hash, + // We've just added the outgoing amount to the sending node, and we're notifying the forward to its + // peer that has just received an incoming htlc, so the outgoing amount added to the sending node + // is the incoming amount for the forwarding node. + outgoing_amount, + incoming_htlc.clone(), + incoming_custom_records.clone(), + next_scid, + outgoing_cltv, + ); + + log::trace!( + "Sending HTLC to intercepor: {} {request}", + interceptor.name() + ); + + let interceptor_clone = Arc::clone(interceptor); + intercepts.spawn(async move { interceptor_clone.intercept_htlc(request).await }); + } + + // Read results from the interceptors and check whether any of them returned an instruction to fail + // the HTLC. If any of the interceptors did return an error, we drop the intercepts + // JoinSet to abort the other interceptor calls that may have not returned yet. + while let Some(res) = intercepts.join_next().await { + match res { + Ok(Ok(records)) => { + // Interceptor call succeeded and indicated that we should proceed with the forward. Merge + // any custom records provided, failing if interceptors provide duplicate values for the + // same key. + for (k, v) in records { + match outgoing_custom_records.entry(k) { + Entry::Occupied(e) => { + let existing_value = e.get(); + if *existing_value == v { + return Err(( + fail_idx, + ForwardingError::DuplicateCustomRecord(k), + )); + } + }, + Entry::Vacant(e) => { + e.insert(v); + }, + }; + } + }, + Ok(Err(e)) => { + drop(intercepts); + return Err((fail_idx, e)); + }, + Err(e) => { + drop(intercepts); + return Err((fail_idx, ForwardingError::InterceptorError(Box::new(e)))); + }, } } - } else { - return Err((fail_idx, ForwardingError::ChannelNotFound(scid))); } - // Once we've taken the "hop" to the destination pubkey, it becomes the source of the next outgoing htlc. + // Once we've taken the "hop" to the destination pubkey, it becomes the source of the next outgoing htlc and + // any outgoing custom records set by the interceptor become the incoming custom records for the next hop. outgoing_node = hop.pubkey; outgoing_amount -= hop.fee_msat; outgoing_cltv -= hop.cltv_expiry_delta; + incoming_custom_records = outgoing_custom_records; // TODO: introduce artificial latency between hops? } @@ -951,7 +1190,9 @@ async fn remove_htlcs( route: Path, payment_hash: PaymentHash, success: bool, + interceptors: Vec>, ) -> Result<(), ForwardingError> { + let mut outgoing_channel_id = None; for (i, hop) in route.hops[0..=resolution_idx].iter().enumerate().rev() { // When we add HTLCs, we do so on the state of the node that sent the htlc along the channel so we need to // look up our incoming node so that we can remove it when we go backwards. For the first htlc, this is just @@ -964,18 +1205,38 @@ async fn remove_htlcs( // As with when we add HTLCs, we remove them one hop at a time (rather than locking for the whole route) to // mimic the behavior of payments in a real network. - match nodes - .lock() - .await - .get_mut(&ShortChannelID::from(hop.short_channel_id)) - { + let mut node_lock = nodes.lock().await; + let incoming_scid = ShortChannelID::from(hop.short_channel_id); + let (_removed_htlc, index) = match node_lock.get_mut(&incoming_scid) { Some(channel) => channel.remove_htlc(&incoming_node, &payment_hash, success)?, None => { return Err(ForwardingError::ChannelNotFound(ShortChannelID::from( hop.short_channel_id, ))) }, + }; + + // We drop our node lock so that we can notify interceptors without blocking other payments processing. + drop(node_lock); + + for interceptor in interceptors.iter() { + log::trace!("Sending resolution to interceptor: {}", interceptor.name()); + + interceptor + .notify_resolution(InterceptResolution { + forwarding_node: hop.pubkey, + incoming_htlc: HtlcRef { + channel_id: incoming_scid, + index, + }, + outgoing_channel_id, + success, + }) + .await + .map_err(ForwardingError::NotifierError)?; } + + outgoing_channel_id = Some(incoming_scid); } Ok(()) @@ -991,20 +1252,35 @@ async fn propagate_payment( route: Path, payment_hash: PaymentHash, sender: Sender>, + interceptors: Vec>, shutdown: Trigger, ) { // If we partially added HTLCs along the route, we need to fail them back to the source to clean up our partial // state. It's possible that we failed with the very first add, and then we don't need to clean anything up. - let notify_result = if let Err((fail_idx, err)) = - add_htlcs(nodes.clone(), source, route.clone(), payment_hash).await + let notify_result = if let Err((fail_idx, err)) = add_htlcs( + nodes.clone(), + source, + route.clone(), + payment_hash, + interceptors.clone(), + ) + .await { if err.is_critical() { shutdown.trigger(); } if let Some(resolution_idx) = fail_idx { - if let Err(e) = - remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await + if let Err(e) = remove_htlcs( + nodes, + resolution_idx, + source, + route, + payment_hash, + false, + interceptors, + ) + .await { if e.is_critical() { shutdown.trigger(); @@ -1032,6 +1308,7 @@ async fn propagate_payment( route, payment_hash, true, + interceptors, ) .await { @@ -1553,6 +1830,21 @@ mod tests { )); } + mock! { + #[derive(Debug)] + TestInterceptor{} + + #[async_trait] + impl Interceptor for TestInterceptor { + async fn intercept_htlc(&self, req: InterceptRequest) -> Result; + async fn notify_resolution( + &self, + res: InterceptResolution, + ) -> Result<(), Box>; + fn name(&self) -> String; + } + } + /// Contains elements required to test dispatch_payment functionality. struct DispatchPaymentTestKit<'a> { graph: SimGraph, @@ -1569,7 +1861,7 @@ mod tests { /// Alice (100) --- (0) Bob (100) --- (0) Carol (100) --- (0) Dave /// /// The nodes pubkeys in this chain of channels are provided in-order for easy access. - async fn new(capacity: u64) -> Self { + async fn new(capacity: u64, interceptors: Vec>) -> Self { let (shutdown, _listener) = triggered::trigger(); let channels = create_simulated_channels(3, capacity); let routing_graph = Arc::new(populate_network_graph(channels.clone()).unwrap()); @@ -1589,8 +1881,13 @@ mod tests { nodes.push(channels.last().unwrap().node_2.policy.pubkey); let kit = DispatchPaymentTestKit { - graph: SimGraph::new(channels.clone(), TaskTracker::new(), shutdown.clone()) - .expect("could not create test graph"), + graph: SimGraph::new( + channels.clone(), + TaskTracker::new(), + interceptors, + shutdown.clone(), + ) + .expect("could not create test graph"), nodes, routing_graph, scorer, @@ -1631,12 +1928,12 @@ mod tests { // Sends a test payment from source to destination and waits for the payment to complete, returning the route // used. - async fn send_test_payemnt( + async fn send_test_payment( &mut self, source: PublicKey, dest: PublicKey, amt: u64, - ) -> Route { + ) -> (Route, Result) { let route = find_payment_route(&source, dest, amt, &self.routing_graph, &self.scorer).unwrap(); @@ -1644,10 +1941,11 @@ mod tests { self.graph .dispatch_payment(source, route.clone(), PaymentHash([1; 32]), sender); + let payment_result = timeout(Duration::from_millis(10), receiver).await; // Assert that we receive from the channel or fail. - assert!(timeout(Duration::from_millis(10), receiver).await.is_ok()); + assert!(payment_result.is_ok()); - route + (route, payment_result.unwrap().unwrap()) } // Sets the balance on the channel to the tuple provided, used to arrange liquidity for testing. @@ -1667,12 +1965,12 @@ mod tests { #[tokio::test] async fn test_successful_dispatch() { let chan_capacity = 500_000_000; - let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![]).await; // Send a payment that should succeed from Alice -> Dave. let mut amt = 20_000; - let route = test_kit - .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + let (route, _) = test_kit + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) .await; let route_total = amt + route.get_total_fees(); @@ -1692,7 +1990,7 @@ mod tests { // machine, so we want to specifically hit it. To do this, we'll try to send double the amount that we just // pushed to Dave back to Bob, expecting a failure on Dave's outgoing link due to insufficient liquidity. let _ = test_kit - .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[1], amt * 2) + .send_test_payment(test_kit.nodes[3], test_kit.nodes[1], amt * 2) .await; assert_eq!(test_kit.channel_balances().await, expected_balances); @@ -1701,7 +1999,7 @@ mod tests { // use 50% of the channel's capacity, so we need to do two payments. amt = bob_to_carol.0 / 2; let _ = test_kit - .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt) + .send_test_payment(test_kit.nodes[1], test_kit.nodes[2], amt) .await; bob_to_carol = (bob_to_carol.0 / 2, bob_to_carol.1 + amt); @@ -1710,7 +2008,7 @@ mod tests { // When we push this amount a second time, all the liquidity should be moved to Carol's end. let _ = test_kit - .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt) + .send_test_payment(test_kit.nodes[1], test_kit.nodes[2], amt) .await; bob_to_carol = (0, chan_capacity); expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave]; @@ -1719,7 +2017,7 @@ mod tests { // Finally, we'll test a multi-hop failure by trying to send from Alice -> Dave. Since Bob's liquidity is // drained, we expect a failure and unchanged balances along the route. let _ = test_kit - .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], 20_000) + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], 20_000) .await; assert_eq!(test_kit.channel_balances().await, expected_balances); @@ -1732,12 +2030,12 @@ mod tests { #[tokio::test] async fn test_successful_multi_hop() { let chan_capacity = 500_000_000; - let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![]).await; // Send a payment that should succeed from Alice -> Dave. let amt = 20_000; - let route = test_kit - .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + let (route, _) = test_kit + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) .await; let route_total = amt + route.get_total_fees(); @@ -1762,12 +2060,12 @@ mod tests { #[tokio::test] async fn test_single_hop_payments() { let chan_capacity = 500_000_000; - let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![]).await; // Send a single hop payment from Alice -> Bob, it will succeed because Alice has all the liquidity. let amt = 150_000; let _ = test_kit - .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[1], amt) + .send_test_payment(test_kit.nodes[0], test_kit.nodes[1], amt) .await; let expected_balances = vec![ @@ -1780,7 +2078,7 @@ mod tests { // Send a single hop payment from Dave -> Carol that will fail due to lack of liquidity, balances should be // unchanged. let _ = test_kit - .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[2], amt) + .send_test_payment(test_kit.nodes[3], test_kit.nodes[2], amt) .await; assert_eq!(test_kit.channel_balances().await, expected_balances); @@ -1794,7 +2092,7 @@ mod tests { #[tokio::test] async fn test_multi_hop_faiulre() { let chan_capacity = 500_000_000; - let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![]).await; // Drain liquidity between Bob and Carol to force failures on Bob's outgoing linke. test_kit @@ -1808,7 +2106,7 @@ mod tests { // Send a payment from Alice -> Dave which we expect to fail leaving balances unaffected. let amt = 150_000; let _ = test_kit - .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) .await; assert_eq!(test_kit.channel_balances().await, expected_balances); @@ -1821,7 +2119,7 @@ mod tests { .await; let _ = test_kit - .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[0], amt) + .send_test_payment(test_kit.nodes[3], test_kit.nodes[0], amt) .await; assert_eq!(test_kit.channel_balances().await, expected_balances); @@ -1830,4 +2128,121 @@ mod tests { test_kit.graph.tasks.close(); test_kit.graph.tasks.wait().await; } + + /// Tests intercepted htlc failures. + #[tokio::test] + async fn test_intercepted_htlc_failure() { + // Test with 2 interceptors where one of them returns a signal to fail the htlc. + let mut mock_interceptor_1 = MockTestInterceptor::new(); + mock_interceptor_1 + .expect_intercept_htlc() + .returning(|_| Ok(CustomRecords::default())); + mock_interceptor_1 + .expect_notify_resolution() + .returning(|_| Ok(())); + + let mut mock_interceptor_2 = MockTestInterceptor::new(); + mock_interceptor_2.expect_intercept_htlc().returning(|_| { + Err(ForwardingError::InterceptorError( + "failing from mock interceptor".into(), + )) + }); + mock_interceptor_2 + .expect_notify_resolution() + .returning(|_| Ok(())); + + let chan_capacity = 500_000_000; + let mock_1 = Arc::new(mock_interceptor_1); + let mock_2 = Arc::new(mock_interceptor_2); + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![mock_1, mock_2]).await; + let expected_balances = vec![(chan_capacity, 0), (chan_capacity, 0), (chan_capacity, 0)]; + + // Send payment where there is enough liquidity but one of the interceptors fails the htlc. + let amt = 150_000; + let (_, result) = test_kit + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + assert!(matches!( + result.unwrap().payment_outcome, + PaymentOutcome::Unknown + )); + + let mut mock_interceptor_1 = MockTestInterceptor::new(); + mock_interceptor_1 + .expect_intercept_htlc() + .returning(|_| Ok(CustomRecords::from([(1000, vec![1])]))); + mock_interceptor_1 + .expect_notify_resolution() + .returning(|_| Ok(())); + + let mut mock_interceptor_2 = MockTestInterceptor::new(); + mock_interceptor_2 + .expect_intercept_htlc() + .returning(|_| Ok(CustomRecords::from([(1000, vec![1])]))); + mock_interceptor_2 + .expect_notify_resolution() + .returning(|_| Ok(())); + + let mock_1 = Arc::new(mock_interceptor_1); + let mock_2 = Arc::new(mock_interceptor_2); + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![mock_1, mock_2]).await; + + // Test intercepted htlc with conflicting records. Conflicting records should fail the + // htlc. + let (_, result) = test_kit + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + assert!(matches!( + result.unwrap().payment_outcome, + PaymentOutcome::Unknown + )); + + test_kit.shutdown.trigger(); + test_kit.graph.tasks.close(); + test_kit.graph.tasks.wait().await; + } + + /// Tests intercepted htlc success. + #[tokio::test] + async fn test_intercepted_htlc_success() { + let mut mock_interceptor_1 = MockTestInterceptor::new(); + mock_interceptor_1 + .expect_intercept_htlc() + .returning(|_| Ok(CustomRecords::default())); + mock_interceptor_1 + .expect_notify_resolution() + .returning(|_| Ok(())); + + let chan_capacity = 500_000_000; + let mock_1 = Arc::new(mock_interceptor_1); + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity, vec![mock_1]).await; + + // Test payment with enough liquidity and interceptor succeeds. + let amt = 150_000; + let (route, result) = test_kit + .send_test_payment(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + let route_total = amt + route.get_total_fees(); + let hop_1_amt = amt + route.paths[0].hops[1].fee_msat; + let expected_balances = vec![ + (chan_capacity - route_total, route_total), + (chan_capacity - hop_1_amt, hop_1_amt), + (chan_capacity - amt, amt), + ]; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + assert!(matches!( + result.unwrap().payment_outcome, + PaymentOutcome::Success + )); + + test_kit.shutdown.trigger(); + test_kit.graph.tasks.close(); + test_kit.graph.tasks.wait().await; + } }