diff --git a/src/chain/mod.rs b/src/chain/mod.rs index 92c4bdb64..5a326be97 100644 --- a/src/chain/mod.rs +++ b/src/chain/mod.rs @@ -9,7 +9,7 @@ pub(crate) mod bitcoind; mod electrum; mod esplora; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -84,7 +84,7 @@ impl WalletSyncStatus { pub(crate) struct ChainSource { kind: ChainSourceKind, - registered_txids: Mutex>, + registered_txids: Mutex>, tx_broadcaster: Arc, logger: Arc, } @@ -113,7 +113,7 @@ impl ChainSource { node_metrics, )?; let kind = ChainSourceKind::Esplora(esplora_chain_source); - let registered_txids = Mutex::new(Vec::new()); + let registered_txids = Mutex::new(HashSet::new()); Ok((Self { kind, registered_txids, tx_broadcaster, logger }, None)) } @@ -133,7 +133,7 @@ impl ChainSource { node_metrics, ); let kind = ChainSourceKind::Electrum(electrum_chain_source); - let registered_txids = Mutex::new(Vec::new()); + let registered_txids = Mutex::new(HashSet::new()); (Self { kind, registered_txids, tx_broadcaster, logger }, None) } @@ -156,7 +156,7 @@ impl ChainSource { ); let best_block = bitcoind_chain_source.poll_best_block().await.ok(); let kind = ChainSourceKind::Bitcoind(bitcoind_chain_source); - let registered_txids = Mutex::new(Vec::new()); + let registered_txids = Mutex::new(HashSet::new()); (Self { kind, registered_txids, tx_broadcaster, logger }, best_block) } @@ -180,7 +180,7 @@ impl ChainSource { ); let best_block = bitcoind_chain_source.poll_best_block().await.ok(); let kind = ChainSourceKind::Bitcoind(bitcoind_chain_source); - let registered_txids = Mutex::new(Vec::new()); + let registered_txids = Mutex::new(HashSet::new()); (Self { kind, registered_txids, tx_broadcaster, logger }, best_block) } @@ -214,7 +214,7 @@ impl ChainSource { } } - pub(crate) fn registered_txids(&self) -> Vec { + pub(crate) fn registered_txids(&self) -> HashSet { self.registered_txids.lock().expect("lock").clone() } @@ -472,7 +472,7 @@ impl ChainSource { impl Filter for ChainSource { fn register_tx(&self, txid: &Txid, script_pubkey: &Script) { - self.registered_txids.lock().expect("lock").push(*txid); + self.registered_txids.lock().expect("lock").insert(*txid); match &self.kind { ChainSourceKind::Esplora(esplora_chain_source) => { esplora_chain_source.register_tx(txid, script_pubkey) diff --git a/src/payment/bolt11.rs b/src/payment/bolt11.rs index 068269997..e3cb948a1 100644 --- a/src/payment/bolt11.rs +++ b/src/payment/bolt11.rs @@ -539,7 +539,7 @@ impl Bolt11Payment { _ => 0, }; if let Some(invoice_amount_msat) = details.amount_msat { - if claimable_amount_msat < invoice_amount_msat - skimmed_fee_msat { + if claimable_amount_msat < invoice_amount_msat.saturating_sub(skimmed_fee_msat) { log_error!( self.logger, "Failed to manually claim payment {} as the claimable amount is less than expected", diff --git a/src/payment/unified.rs b/src/payment/unified.rs index 3708afe8e..2ad77f772 100644 --- a/src/payment/unified.rs +++ b/src/payment/unified.rs @@ -129,9 +129,9 @@ impl UnifiedPayment { pub fn receive( &self, amount_sats: u64, description: &str, expiry_sec: u32, ) -> Result { - let onchain_address = self.onchain_payment.new_address()?; + let amount_msats = amount_sats.checked_mul(1_000).ok_or(Error::InvalidAmount)?; - let amount_msats = amount_sats * 1_000; + let onchain_address = self.onchain_payment.new_address()?; let bolt12_offer = match self.bolt12_payment.receive_inner(amount_msats, description, None, None) { diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index 309d5bf4d..76835b38a 100644 --- a/tests/integration_tests_rust.rs +++ b/tests/integration_tests_rust.rs @@ -1680,6 +1680,18 @@ async fn generate_bip21_uri() { assert!(uni_payment.contains("lno=")); } +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn unified_receive_rejects_msat_overflow() { + let (bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let chain_source = random_chain_source(&bitcoind, &electrsd); + let node = setup_node(&chain_source, random_config(true)); + + assert_eq!( + Err(NodeError::InvalidAmount), + node.unified_payment().receive(u64::MAX, "asdf", 4_000) + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn unified_send_receive_bip21_uri() { let (bitcoind, electrsd) = setup_bitcoind_and_electrsd();