diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index 29954e86624..59b2fd3d8e0 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -15,23 +15,23 @@ pub(crate) mod fuzz_wrappers; #[macro_use] pub mod ser_macros; +#[cfg(fuzzing)] +pub mod base32; +#[cfg(not(fuzzing))] +pub(crate) mod base32; pub mod errors; -pub mod ser; pub mod message_signing; pub mod persist; pub mod scid_utils; +pub mod ser; pub mod sweep; pub mod wakers; -#[cfg(fuzzing)] -pub mod base32; -#[cfg(not(fuzzing))] -pub(crate) mod base32; -pub(crate) mod atomic_counter; pub(crate) mod async_poll; +pub(crate) mod atomic_counter; pub(crate) mod byte_utils; -pub(crate) mod transaction_utils; pub mod hash_tables; +pub(crate) mod transaction_utils; #[cfg(feature = "std")] pub(crate) mod time; @@ -43,8 +43,8 @@ pub mod indexed_map; pub(crate) mod macro_logger; // These have to come after macro_logger to build -pub mod logger; pub mod config; +pub mod logger; #[cfg(any(test, feature = "_test_utils"))] pub mod test_utils; diff --git a/lightning/src/util/scid_utils.rs b/lightning/src/util/scid_utils.rs index 4e3fa37dcf0..b9dcc4688e8 100644 --- a/lightning/src/util/scid_utils.rs +++ b/lightning/src/util/scid_utils.rs @@ -49,7 +49,9 @@ pub fn vout_from_scid(short_channel_id: u64) -> u16 { /// Constructs a `short_channel_id` using the components pieces. Results in an error /// if the block height, tx index, or vout index overflow the maximum sizes. -pub fn scid_from_parts(block: u64, tx_index: u64, vout_index: u64) -> Result { +pub fn scid_from_parts( + block: u64, tx_index: u64, vout_index: u64, +) -> Result { if block > MAX_SCID_BLOCK { return Err(ShortChannelIdError::BlockOverflow); } @@ -71,12 +73,12 @@ pub fn scid_from_parts(block: u64, tx_index: u64, vout_index: u64) -> Result(&self, highest_seen_blockheight: u32, chain_hash: &ChainHash, fake_scid_rand_bytes: &[u8; 32], entropy_source: &ES) -> u64 - where ES::Target: EntropySource, + pub(crate) fn get_fake_scid( + &self, highest_seen_blockheight: u32, chain_hash: &ChainHash, + fake_scid_rand_bytes: &[u8; 32], entropy_source: &ES, + ) -> u64 + where + ES::Target: EntropySource, { // Ensure we haven't created a namespace that doesn't fit into the 3 bits we've allocated for // namespaces. @@ -118,27 +123,38 @@ pub(crate) mod fake_scid { let rand_bytes = entropy_source.get_secure_random_bytes(); let segwit_activation_height = segwit_activation_height(chain_hash); - let mut blocks_since_segwit_activation = highest_seen_blockheight.saturating_sub(segwit_activation_height); + let mut blocks_since_segwit_activation = + highest_seen_blockheight.saturating_sub(segwit_activation_height); // We want to ensure that this fake channel won't conflict with any transactions we haven't // seen yet, in case `highest_seen_blockheight` is updated before we get full information // about transactions confirmed in the given block. - blocks_since_segwit_activation = blocks_since_segwit_activation.saturating_sub(MAX_SCID_BLOCKS_FROM_NOW); + blocks_since_segwit_activation = + blocks_since_segwit_activation.saturating_sub(MAX_SCID_BLOCKS_FROM_NOW); let rand_for_height = u32::from_be_bytes(rand_bytes[..4].try_into().unwrap()); - let fake_scid_height = segwit_activation_height + rand_for_height % (blocks_since_segwit_activation + 1); + let fake_scid_height = + segwit_activation_height + rand_for_height % (blocks_since_segwit_activation + 1); let rand_for_tx_index = u32::from_be_bytes(rand_bytes[4..8].try_into().unwrap()); let fake_scid_tx_index = rand_for_tx_index % MAX_TX_INDEX; // Put the scid in the given namespace. - let fake_scid_vout = self.get_encrypted_vout(fake_scid_height, fake_scid_tx_index, fake_scid_rand_bytes); - scid_utils::scid_from_parts(fake_scid_height as u64, fake_scid_tx_index as u64, fake_scid_vout as u64).unwrap() + let fake_scid_vout = + self.get_encrypted_vout(fake_scid_height, fake_scid_tx_index, fake_scid_rand_bytes); + scid_utils::scid_from_parts( + fake_scid_height as u64, + fake_scid_tx_index as u64, + fake_scid_vout as u64, + ) + .unwrap() } /// We want to ensure that a 3rd party can't identify a payment as belong to a given /// `Namespace`. Therefore, we encrypt it using a random bytes provided by `ChannelManager`. - fn get_encrypted_vout(&self, block_height: u32, tx_index: u32, fake_scid_rand_bytes: &[u8; 32]) -> u8 { + fn get_encrypted_vout( + &self, block_height: u32, tx_index: u32, fake_scid_rand_bytes: &[u8; 32], + ) -> u8 { let mut salt = [0 as u8; 8]; let block_height_bytes = block_height.to_be_bytes(); salt[0..4].copy_from_slice(&block_height_bytes); @@ -161,7 +177,9 @@ pub(crate) mod fake_scid { } /// Returns whether the given fake scid falls into the phantom namespace. - pub fn is_valid_phantom(fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash) -> bool { + pub fn is_valid_phantom( + fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash, + ) -> bool { let block_height = scid_utils::block_from_scid(scid); let tx_index = scid_utils::tx_index_from_scid(scid); let namespace = Namespace::Phantom; @@ -171,7 +189,9 @@ pub(crate) mod fake_scid { } /// Returns whether the given fake scid falls into the intercept namespace. - pub fn is_valid_intercept(fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash) -> bool { + pub fn is_valid_intercept( + fake_scid_rand_bytes: &[u8; 32], scid: u64, chain_hash: &ChainHash, + ) -> bool { let block_height = scid_utils::block_from_scid(scid); let tx_index = scid_utils::tx_index_from_scid(scid); let namespace = Namespace::Intercept; @@ -182,12 +202,16 @@ pub(crate) mod fake_scid { #[cfg(test)] mod tests { - use bitcoin::constants::ChainHash; - use bitcoin::network::Network; - use crate::util::scid_utils::fake_scid::{is_valid_intercept, is_valid_phantom, MAINNET_SEGWIT_ACTIVATION_HEIGHT, MAX_TX_INDEX, MAX_NAMESPACES, Namespace, NAMESPACE_ID_BITMASK, segwit_activation_height, TEST_SEGWIT_ACTIVATION_HEIGHT}; + use crate::sync::Arc; use crate::util::scid_utils; + use crate::util::scid_utils::fake_scid::{ + is_valid_intercept, is_valid_phantom, segwit_activation_height, Namespace, + MAINNET_SEGWIT_ACTIVATION_HEIGHT, MAX_NAMESPACES, MAX_TX_INDEX, NAMESPACE_ID_BITMASK, + TEST_SEGWIT_ACTIVATION_HEIGHT, + }; use crate::util::test_utils; - use crate::sync::Arc; + use bitcoin::constants::ChainHash; + use bitcoin::network::Network; #[test] fn namespace_identifier_is_within_range() { @@ -203,7 +227,10 @@ pub(crate) mod fake_scid { #[test] fn test_segwit_activation_height() { let mainnet_genesis = ChainHash::using_genesis_block(Network::Bitcoin); - assert_eq!(segwit_activation_height(&mainnet_genesis), MAINNET_SEGWIT_ACTIVATION_HEIGHT); + assert_eq!( + segwit_activation_height(&mainnet_genesis), + MAINNET_SEGWIT_ACTIVATION_HEIGHT + ); let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet); assert_eq!(segwit_activation_height(&testnet_genesis), TEST_SEGWIT_ACTIVATION_HEIGHT); @@ -221,7 +248,8 @@ pub(crate) mod fake_scid { let fake_scid_rand_bytes = [0; 32]; let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet); let valid_encrypted_vout = namespace.get_encrypted_vout(0, 0, &fake_scid_rand_bytes); - let valid_fake_scid = scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap(); + let valid_fake_scid = + scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap(); assert!(is_valid_phantom(&fake_scid_rand_bytes, valid_fake_scid, &testnet_genesis)); let invalid_fake_scid = scid_utils::scid_from_parts(1, 0, 12).unwrap(); assert!(!is_valid_phantom(&fake_scid_rand_bytes, invalid_fake_scid, &testnet_genesis)); @@ -233,10 +261,15 @@ pub(crate) mod fake_scid { let fake_scid_rand_bytes = [0; 32]; let testnet_genesis = ChainHash::using_genesis_block(Network::Testnet); let valid_encrypted_vout = namespace.get_encrypted_vout(0, 0, &fake_scid_rand_bytes); - let valid_fake_scid = scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap(); + let valid_fake_scid = + scid_utils::scid_from_parts(1, 0, valid_encrypted_vout as u64).unwrap(); assert!(is_valid_intercept(&fake_scid_rand_bytes, valid_fake_scid, &testnet_genesis)); let invalid_fake_scid = scid_utils::scid_from_parts(1, 0, 12).unwrap(); - assert!(!is_valid_intercept(&fake_scid_rand_bytes, invalid_fake_scid, &testnet_genesis)); + assert!(!is_valid_intercept( + &fake_scid_rand_bytes, + invalid_fake_scid, + &testnet_genesis + )); } #[test] @@ -244,9 +277,15 @@ pub(crate) mod fake_scid { let mainnet_genesis = ChainHash::using_genesis_block(Network::Bitcoin); let seed = [0; 32]; let fake_scid_rand_bytes = [1; 32]; - let keys_manager = Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet)); + let keys_manager = + Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet)); let namespace = Namespace::Phantom; - let fake_scid = namespace.get_fake_scid(500_000, &mainnet_genesis, &fake_scid_rand_bytes, &keys_manager); + let fake_scid = namespace.get_fake_scid( + 500_000, + &mainnet_genesis, + &fake_scid_rand_bytes, + &keys_manager, + ); let fake_height = scid_utils::block_from_scid(fake_scid); assert!(fake_height >= MAINNET_SEGWIT_ACTIVATION_HEIGHT); @@ -298,8 +337,17 @@ mod tests { assert_eq!(scid_from_parts(0x00000001, 0x00000002, 0x0003).unwrap(), 0x000001_000002_0003); assert_eq!(scid_from_parts(0x00111111, 0x00222222, 0x3333).unwrap(), 0x111111_222222_3333); assert_eq!(scid_from_parts(0x00ffffff, 0x00ffffff, 0xffff).unwrap(), 0xffffff_ffffff_ffff); - assert_eq!(scid_from_parts(0x01ffffff, 0x00000000, 0x0000).err().unwrap(), ShortChannelIdError::BlockOverflow); - assert_eq!(scid_from_parts(0x00000000, 0x01ffffff, 0x0000).err().unwrap(), ShortChannelIdError::TxIndexOverflow); - assert_eq!(scid_from_parts(0x00000000, 0x00000000, 0x010000).err().unwrap(), ShortChannelIdError::VoutIndexOverflow); + assert_eq!( + scid_from_parts(0x01ffffff, 0x00000000, 0x0000).err().unwrap(), + ShortChannelIdError::BlockOverflow + ); + assert_eq!( + scid_from_parts(0x00000000, 0x01ffffff, 0x0000).err().unwrap(), + ShortChannelIdError::TxIndexOverflow + ); + assert_eq!( + scid_from_parts(0x00000000, 0x00000000, 0x010000).err().unwrap(), + ShortChannelIdError::VoutIndexOverflow + ); } } diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index 281789067ea..0f8dae0eb8d 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -13,39 +13,41 @@ //! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager //! [`ChannelMonitor`]: crate::chain::channelmonitor::ChannelMonitor -use crate::prelude::*; use crate::io::{self, BufRead, Read, Write}; use crate::io_extras::{copy, sink}; -use core::hash::Hash; +use crate::prelude::*; use crate::sync::{Mutex, RwLock}; use core::cmp; +use core::hash::Hash; use core::ops::Deref; use alloc::collections::BTreeMap; -use bitcoin::secp256k1::{PublicKey, SecretKey}; -use bitcoin::secp256k1::constants::{PUBLIC_KEY_SIZE, SECRET_KEY_SIZE, COMPACT_SIGNATURE_SIZE, SCHNORR_SIGNATURE_SIZE}; -use bitcoin::secp256k1::ecdsa; -use bitcoin::secp256k1::schnorr; use bitcoin::amount::Amount; +use bitcoin::consensus::Encodable; use bitcoin::constants::ChainHash; +use bitcoin::hash_types::{BlockHash, Txid}; +use bitcoin::hashes::hmac::Hmac; +use bitcoin::hashes::sha256::Hash as Sha256; +use bitcoin::hashes::sha256d::Hash as Sha256dHash; use bitcoin::script::{self, ScriptBuf}; +use bitcoin::secp256k1::constants::{ + COMPACT_SIGNATURE_SIZE, PUBLIC_KEY_SIZE, SCHNORR_SIGNATURE_SIZE, SECRET_KEY_SIZE, +}; +use bitcoin::secp256k1::ecdsa; +use bitcoin::secp256k1::schnorr; +use bitcoin::secp256k1::{PublicKey, SecretKey}; use bitcoin::transaction::{OutPoint, Transaction, TxOut}; use bitcoin::{consensus, Witness}; -use bitcoin::consensus::Encodable; -use bitcoin::hashes::hmac::Hmac; -use bitcoin::hashes::sha256d::Hash as Sha256dHash; -use bitcoin::hashes::sha256::Hash as Sha256; -use bitcoin::hash_types::{Txid, BlockHash}; use dnssec_prover::rr::Name; -use core::time::Duration; use crate::chain::ClaimId; use crate::ln::msgs::DecodeError; #[cfg(taproot)] use crate::ln::msgs::PartialSignatureWithNonce; -use crate::types::payment::{PaymentPreimage, PaymentHash, PaymentSecret}; +use crate::types::payment::{PaymentHash, PaymentPreimage, PaymentSecret}; +use core::time::Duration; use crate::util::byte_utils::{be48_to_array, slice_to_be48}; use crate::util::string::UntrustedString; @@ -79,24 +81,22 @@ impl Writer for W { struct BufReader<'a, R: Read> { inner: &'a mut R, buf: [u8; 1], - is_consumed: bool + is_consumed: bool, } impl<'a, R: Read> BufReader<'a, R> { /// Creates a [`BufReader`] which will read from the given `inner`. pub fn new(inner: &'a mut R) -> Self { - BufReader { - inner, - buf: [0; 1], - is_consumed: true - } + BufReader { inner, buf: [0; 1], is_consumed: true } } } impl<'a, R: Read> Read for BufReader<'a, R> { #[inline] fn read(&mut self, output: &mut [u8]) -> io::Result { - if output.is_empty() { return Ok(0); } + if output.is_empty() { + return Ok(0); + } let mut offset = 0; if !self.is_consumed { output[0] = self.buf[0]; @@ -306,14 +306,17 @@ pub trait Writeable { } impl<'a, T: Writeable> Writeable for &'a T { - fn write(&self, writer: &mut W) -> Result<(), io::Error> { (*self).write(writer) } + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + (*self).write(writer) + } } /// A trait that various LDK types implement allowing them to be read in from a [`Read`]. /// /// This is not exported to bindings users as we only export serialization to/from byte arrays instead pub trait Readable - where Self: Sized +where + Self: Sized, { /// Reads a `Self` in from the given [`Read`]. fn read(reader: &mut R) -> Result; @@ -321,7 +324,10 @@ pub trait Readable /// A trait that various LDK types implement allowing them to be read in from a /// [`io::Cursor`]. -pub(crate) trait CursorReadable where Self: Sized { +pub(crate) trait CursorReadable +where + Self: Sized, +{ /// Reads a `Self` in from the given [`Read`]. fn read>(reader: &mut io::Cursor) -> Result; } @@ -331,7 +337,8 @@ pub(crate) trait CursorReadable where Self: Sized { /// /// This is not exported to bindings users as we only export serialization to/from byte arrays instead pub trait ReadableArgs

- where Self: Sized +where + Self: Sized, { /// Reads a `Self` in from the given [`Read`]. fn read(reader: &mut R, params: P) -> Result; @@ -346,7 +353,9 @@ pub(crate) trait LengthRead: Read { /// A trait that various higher-level LDK types implement allowing them to be read in /// from a Read given some additional set of arguments which is required to deserialize, requiring /// the implementer to provide the total length of the read. -pub(crate) trait LengthReadableArgs

where Self: Sized +pub(crate) trait LengthReadableArgs

+where + Self: Sized, { /// Reads a `Self` in from the given [`LengthRead`]. fn read(reader: &mut R, params: P) -> Result; @@ -354,7 +363,9 @@ pub(crate) trait LengthReadableArgs

where Self: Sized /// A trait that various higher-level LDK types implement allowing them to be read in /// from a [`Read`], requiring the implementer to provide the total length of the read. -pub(crate) trait LengthReadable where Self: Sized +pub(crate) trait LengthReadable +where + Self: Sized, { /// Reads a `Self` in from the given [`LengthRead`]. fn read(reader: &mut R) -> Result; @@ -364,7 +375,8 @@ pub(crate) trait LengthReadable where Self: Sized /// /// This is not exported to bindings users as we only export serialization to/from byte arrays instead pub trait MaybeReadable - where Self: Sized +where + Self: Sized, { /// Reads a `Self` in from the given [`Read`]. fn read(reader: &mut R) -> Result, DecodeError>; @@ -397,7 +409,9 @@ impl> ReadableArgs for RequiredWrapper { /// to a `RequiredWrapper` in a way that works for `field: T = t;` as /// well. Thus, we assume `Into for T` does nothing and use that. impl From for RequiredWrapper { - fn from(t: T) -> RequiredWrapper { RequiredWrapper(Some(t)) } + fn from(t: T) -> RequiredWrapper { + RequiredWrapper(Some(t)) + } } /// Wrapper to read a required (non-optional) TLV record that may have been upgraded without @@ -409,7 +423,9 @@ impl MaybeReadable for UpgradableRequired { #[inline] fn read(reader: &mut R) -> Result, DecodeError> { let tlv = MaybeReadable::read(reader)?; - if let Some(tlv) = tlv { return Ok(Some(Self(Some(tlv)))) } + if let Some(tlv) = tlv { + return Ok(Some(Self(Some(tlv)))); + } Ok(None) } } @@ -443,9 +459,7 @@ impl Writeable for BigSize { #[inline] fn write(&self, writer: &mut W) -> Result<(), io::Error> { match self.0 { - 0..=0xFC => { - (self.0 as u8).write(writer) - }, + 0..=0xFC => (self.0 as u8).write(writer), 0xFD..=0xFFFF => { 0xFDu8.write(writer)?; (self.0 as u16).write(writer) @@ -473,7 +487,7 @@ impl Readable for BigSize { } else { Ok(BigSize(x)) } - } + }, 0xFE => { let x: u32 = Readable::read(reader)?; if x < 0x10000 { @@ -481,7 +495,7 @@ impl Readable for BigSize { } else { Ok(BigSize(x as u64)) } - } + }, 0xFD => { let x: u16 = Readable::read(reader)?; if x < 0xFD { @@ -489,8 +503,8 @@ impl Readable for BigSize { } else { Ok(BigSize(x as u64)) } - } - n => Ok(BigSize(n as u64)) + }, + n => Ok(BigSize(n as u64)), } } } @@ -522,8 +536,8 @@ impl Readable for CollectionLength { fn read(r: &mut R) -> Result { let mut val: u64 = ::read(r)? as u64; if val == 0xffff { - val = ::read(r)? - .checked_add(0xffff).ok_or(DecodeError::InvalidValue)?; + val = + ::read(r)?.checked_add(0xffff).ok_or(DecodeError::InvalidValue)?; } Ok(CollectionLength(val)) } @@ -547,7 +561,7 @@ macro_rules! impl_writeable_primitive { #[inline] fn write(&self, writer: &mut W) -> Result<(), io::Error> { // Skip any full leading 0 bytes when writing (in BE): - writer.write_all(&self.0.to_be_bytes()[(self.0.leading_zeros()/8) as usize..$len]) + writer.write_all(&self.0.to_be_bytes()[(self.0.leading_zeros() / 8) as usize..$len]) } } impl Readable for $val_type { @@ -560,12 +574,14 @@ macro_rules! impl_writeable_primitive { } impl Readable for HighZeroBytesDroppedBigSize<$val_type> { #[inline] - fn read(reader: &mut R) -> Result, DecodeError> { + fn read( + reader: &mut R, + ) -> Result, DecodeError> { // We need to accept short reads (read_len == 0) as "EOF" and handle them as simply // the high bytes being dropped. To do so, we start reading into the middle of buf // and then convert the appropriate number of bytes with extra high bytes out of // buf. - let mut buf = [0; $len*2]; + let mut buf = [0; $len * 2]; let mut read_len = reader.read(&mut buf[$len..])?; let mut total_read_len = read_len; while read_len != 0 && total_read_len != $len { @@ -585,9 +601,11 @@ macro_rules! impl_writeable_primitive { } } impl From<$val_type> for HighZeroBytesDroppedBigSize<$val_type> { - fn from(val: $val_type) -> Self { Self(val) } + fn from(val: $val_type) -> Self { + Self(val) + } } - } + }; } impl_writeable_primitive!(u128, 16); @@ -617,7 +635,7 @@ impl Readable for u8 { impl Writeable for bool { #[inline] fn write(&self, writer: &mut W) -> Result<(), io::Error> { - writer.write_all(&[if *self {1} else {0}]) + writer.write_all(&[if *self { 1 } else { 0 }]) } } impl Readable for bool { @@ -633,14 +651,15 @@ impl Readable for bool { } macro_rules! impl_array { - ($size:expr, $ty: ty) => ( + ($size:expr, $ty: ty) => { impl Writeable for [$ty; $size] { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { let mut out = [0; $size * core::mem::size_of::<$ty>()]; for (idx, v) in self.iter().enumerate() { let startpos = idx * core::mem::size_of::<$ty>(); - out[startpos..startpos + core::mem::size_of::<$ty>()].copy_from_slice(&v.to_be_bytes()); + out[startpos..startpos + core::mem::size_of::<$ty>()] + .copy_from_slice(&v.to_be_bytes()); } w.write_all(&out) } @@ -661,7 +680,7 @@ macro_rules! impl_array { Ok(res) } } - ); + }; } impl_array!(3, u8); // for rgb, ISO 4217 code @@ -697,7 +716,9 @@ impl Readable for WithoutLength { } } impl<'a> From<&'a String> for WithoutLength<&'a String> { - fn from(s: &'a String) -> Self { Self(s) } + fn from(s: &'a String) -> Self { + Self(s) + } } impl Writeable for UntrustedString { @@ -716,7 +737,7 @@ impl Readable for UntrustedString { impl Writeable for WithoutLength<&UntrustedString> { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { - WithoutLength(&self.0.0).write(w) + WithoutLength(&self.0 .0).write(w) } } impl Readable for WithoutLength { @@ -734,11 +755,15 @@ trait AsWriteableSlice { impl AsWriteableSlice for &Vec { type Inner = T; - fn as_slice(&self) -> &[T] { &self } + fn as_slice(&self) -> &[T] { + &self + } } impl AsWriteableSlice for &[T] { type Inner = T; - fn as_slice(&self) -> &[T] { &self } + fn as_slice(&self) -> &[T] { + &self + } } impl Writeable for WithoutLength { @@ -758,8 +783,10 @@ impl Readable for WithoutLength> { loop { let mut track_read = ReadTrackingReader::new(reader); match MaybeReadable::read(&mut track_read) { - Ok(Some(v)) => { values.push(v); }, - Ok(None) => { }, + Ok(Some(v)) => { + values.push(v); + }, + Ok(None) => {}, // If we failed to read any bytes at all, we reached the end of our TLV // stream and have simply exhausted all entries. Err(ref e) if e == &DecodeError::ShortRead && !track_read.have_read => break, @@ -770,7 +797,9 @@ impl Readable for WithoutLength> { } } impl<'a, T> From<&'a Vec> for WithoutLength<&'a Vec> { - fn from(v: &'a Vec) -> Self { Self(v) } + fn from(v: &'a Vec) -> Self { + Self(v) + } } impl Writeable for WithoutLength<&ScriptBuf> { @@ -824,7 +853,9 @@ impl + Clone, T: Writeable> Writeable for IterableOwned { impl Writeable for $ty - where K: Writeable + Eq + $keybound, V: Writeable + where + K: Writeable + Eq + $keybound, + V: Writeable, { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { @@ -838,7 +869,9 @@ macro_rules! impl_for_map { } impl Readable for $ty - where K: Readable + Eq + $keybound, V: MaybeReadable + where + K: Readable + Eq + $keybound, + V: MaybeReadable, { #[inline] fn read(r: &mut R) -> Result { @@ -856,7 +889,7 @@ macro_rules! impl_for_map { Ok(ret) } } - } + }; } impl_for_map!(BTreeMap, Ord, |_| BTreeMap::new()); @@ -864,7 +897,8 @@ impl_for_map!(HashMap, Hash, |len| hash_map_with_capacity(len)); // HashSet impl Writeable for HashSet -where T: Writeable + Eq + Hash +where + T: Writeable + Eq + Hash, { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { @@ -877,15 +911,19 @@ where T: Writeable + Eq + Hash } impl Readable for HashSet -where T: Readable + Eq + Hash +where + T: Readable + Eq + Hash, { #[inline] fn read(r: &mut R) -> Result { let len: CollectionLength = Readable::read(r)?; - let mut ret = hash_set_with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::())); + let mut ret = hash_set_with_capacity(cmp::min( + len.0 as usize, + MAX_BUF_SIZE / core::mem::size_of::(), + )); for _ in 0..len.0 { if !ret.insert(T::read(r)?) { - return Err(DecodeError::InvalidValue) + return Err(DecodeError::InvalidValue); } } Ok(ret) @@ -1128,7 +1166,8 @@ impl Writeable for PartialSignatureWithNonce { impl Readable for PartialSignatureWithNonce { fn read(r: &mut R) -> Result { let partial_signature_buf: [u8; SECRET_KEY_SIZE] = Readable::read(r)?; - let partial_signature = musig2::types::PartialSignature::from_slice(&partial_signature_buf).map_err(|_| DecodeError::InvalidValue)?; + let partial_signature = musig2::types::PartialSignature::from_slice(&partial_signature_buf) + .map_err(|_| DecodeError::InvalidValue)?; let public_nonce: musig2::types::PublicNonce = Readable::read(r)?; Ok(PartialSignatureWithNonce(partial_signature, public_nonce)) } @@ -1254,14 +1293,13 @@ impl Writeable for Option { Some(ref data) => { BigSize(data.serialized_length() as u64 + 1).write(w)?; data.write(w)?; - } + }, } Ok(()) } } -impl Readable for Option -{ +impl Readable for Option { fn read(r: &mut R) -> Result { let len: BigSize = Readable::read(r)?; match len.0 { @@ -1269,7 +1307,7 @@ impl Readable for Option len => { let mut reader = FixedLengthReader::new(r, len - 1); Ok(Some(Readable::read(&mut reader)?)) - } + }, } } } @@ -1280,7 +1318,6 @@ impl Writeable for Amount { } } - impl Readable for Amount { fn read(r: &mut R) -> Result { let amount: u64 = Readable::read(r)?; @@ -1343,10 +1380,7 @@ impl Readable for OutPoint { fn read(r: &mut R) -> Result { let txid = Readable::read(r)?; let vout = Readable::read(r)?; - Ok(OutPoint { - txid, - vout, - }) + Ok(OutPoint { txid, vout }) } } @@ -1366,13 +1400,17 @@ macro_rules! impl_consensus_ser { let mut reader = BufReader::<_>::new(r); match consensus::encode::Decodable::consensus_decode(&mut reader) { Ok(t) => Ok(t), - Err(consensus::encode::Error::Io(ref e)) if e.kind() == io::ErrorKind::UnexpectedEof => Err(DecodeError::ShortRead), + Err(consensus::encode::Error::Io(ref e)) + if e.kind() == io::ErrorKind::UnexpectedEof => + { + Err(DecodeError::ShortRead) + }, Err(consensus::encode::Error::Io(e)) => Err(DecodeError::Io(e.kind().into())), Err(_) => Err(DecodeError::InvalidValue), } } } - } + }; } impl_consensus_ser!(Transaction); impl_consensus_ser!(TxOut); @@ -1479,11 +1517,8 @@ impl Hostname { /// Check if the chars in `s` are allowed to be included in a [`Hostname`]. pub(crate) fn str_is_valid_hostname(s: &str) -> bool { - s.len() <= 255 && - s.chars().all(|c| - c.is_ascii_alphanumeric() || - c == '.' || c == '_' || c == '-' - ) + s.len() <= 255 + && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '_' || c == '-') } } @@ -1633,10 +1668,10 @@ impl Readable for ClaimId { #[cfg(test)] mod tests { + use crate::prelude::*; + use crate::util::ser::{Hostname, Readable, Writeable}; use bitcoin::hex::FromHex; use bitcoin::secp256k1::ecdsa; - use crate::util::ser::{Readable, Hostname, Writeable}; - use crate::prelude::*; #[test] fn hostname_conversion() { @@ -1680,7 +1715,7 @@ mod tests { "fe00010000", "feffffffff", "ff0000000100000000", - "ffffffffffffffffff" + "ffffffffffffffffff", ]; for i in 0..=7 { let mut stream = crate::io::Cursor::new(>::from_hex(bytes[i]).unwrap()); @@ -1699,14 +1734,20 @@ mod tests { "fd", "fe", "ff", - "" + "", ]; for i in 0..=9 { let mut stream = crate::io::Cursor::new(>::from_hex(err_bytes[i]).unwrap()); if i < 3 { - assert_eq!(super::BigSize::read(&mut stream).err(), Some(crate::ln::msgs::DecodeError::InvalidValue)); + assert_eq!( + super::BigSize::read(&mut stream).err(), + Some(crate::ln::msgs::DecodeError::InvalidValue) + ); } else { - assert_eq!(super::BigSize::read(&mut stream).err(), Some(crate::ln::msgs::DecodeError::ShortRead)); + assert_eq!( + super::BigSize::read(&mut stream).err(), + Some(crate::ln::msgs::DecodeError::ShortRead) + ); } } } diff --git a/lightning/src/util/ser_macros.rs b/lightning/src/util/ser_macros.rs index 0703aac9e84..d046a28e7d8 100644 --- a/lightning/src/util/ser_macros.rs +++ b/lightning/src/util/ser_macros.rs @@ -82,10 +82,11 @@ macro_rules! _encode_tlv { #[doc(hidden)] #[macro_export] macro_rules! _check_encoded_tlv_order { - ($last_type: expr, $type: expr, (static_value, $value: expr)) => { }; + ($last_type: expr, $type: expr, (static_value, $value: expr)) => {}; ($last_type: expr, $type: expr, $fieldty: tt) => { if let Some(t) = $last_type { - #[allow(unused_comparisons)] // Note that $type may be 0 making the following comparison always false + // Note that $type may be 0 making the following comparison always false + #[allow(unused_comparisons)] (debug_assert!(t < $type)) } $last_type = Some($type); @@ -187,22 +188,28 @@ macro_rules! _get_varint_length_prefixed_tlv_length { ($len: expr, $type: expr, $field: expr, (default_value, $default: expr)) => { $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, required) }; - ($len: expr, $type: expr, $field: expr, (static_value, $value: expr)) => { - }; + ($len: expr, $type: expr, $field: expr, (static_value, $value: expr)) => {}; ($len: expr, $type: expr, $field: expr, required) => { BigSize($type).write(&mut $len).expect("No in-memory data may fail to serialize"); let field_len = $field.serialized_length(); - BigSize(field_len as u64).write(&mut $len).expect("No in-memory data may fail to serialize"); + BigSize(field_len as u64) + .write(&mut $len) + .expect("No in-memory data may fail to serialize"); $len.0 += field_len; }; ($len: expr, $type: expr, $field: expr, required_vec) => { - $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $crate::util::ser::WithoutLength(&$field), required); + let field = $crate::util::ser::WithoutLength(&$field); + $crate::_get_varint_length_prefixed_tlv_length!($len, $type, field, required); }; ($len: expr, $optional_type: expr, $optional_field: expr, option) => { if let Some(ref field) = $optional_field { - BigSize($optional_type).write(&mut $len).expect("No in-memory data may fail to serialize"); + BigSize($optional_type) + .write(&mut $len) + .expect("No in-memory data may fail to serialize"); let field_len = field.serialized_length(); - BigSize(field_len as u64).write(&mut $len).expect("No in-memory data may fail to serialize"); + BigSize(field_len as u64) + .write(&mut $len) + .expect("No in-memory data may fail to serialize"); $len.0 += field_len; } }; @@ -215,7 +222,8 @@ macro_rules! _get_varint_length_prefixed_tlv_length { $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, option); }; ($len: expr, $type: expr, $field: expr, (option, encoding: ($fieldty: ty, $encoding: ident))) => { - $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field.map(|f| $encoding(f)), option); + let field = $field.map(|f| $encoding(f)); + $crate::_get_varint_length_prefixed_tlv_length!($len, $type, field, option); }; ($len: expr, $type: expr, $field: expr, upgradable_required) => { $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, required); @@ -260,17 +268,20 @@ macro_rules! _encode_varint_length_prefixed_tlv { #[macro_export] macro_rules! _check_decoded_tlv_order { ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false - let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; + // Note that $type may be 0 making the second comparison always false + #[allow(unused_comparisons)] + let invalid_order = + ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; if invalid_order { $field = $default.into(); } }}; - ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (static_value, $value: expr)) => { - }; + ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (static_value, $value: expr)) => {}; ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, required) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false - let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; + // Note that $type may be 0 making the second comparison always false + #[allow(unused_comparisons)] + let invalid_order = + ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; if invalid_order { return Err(DecodeError::InvalidValue); } @@ -313,7 +324,8 @@ macro_rules! _check_decoded_tlv_order { #[macro_export] macro_rules! _check_missing_tlv { ($last_seen_type: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false + // Note that $type may be 0 making the second comparison always false + #[allow(unused_comparisons)] let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; if missing_req_type { $field = $default.into(); @@ -323,7 +335,8 @@ macro_rules! _check_missing_tlv { $field = $value; }; ($last_seen_type: expr, $type: expr, $field: ident, required) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false + // Note that $type may be 0 making the second comparison always false + #[allow(unused_comparisons)] let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; if missing_req_type { return Err(DecodeError::InvalidValue); @@ -454,8 +467,12 @@ macro_rules! _decode_tlv { #[doc(hidden)] #[macro_export] macro_rules! _decode_tlv_stream_match_check { - ($val: ident, $type: expr, (static_value, $value: expr)) => { false }; - ($val: ident, $type: expr, $fieldty: tt) => { $val == $type } + ($val: ident, $type: expr, (static_value, $value: expr)) => { + false + }; + ($val: ident, $type: expr, $fieldty: tt) => { + $val == $type + }; } /// Implements the TLVs deserialization part in a [`Readable`] implementation of a struct. @@ -706,7 +723,7 @@ macro_rules! write_ver_prefix { ($stream: expr, $this_version: expr, $min_version_that_can_read_this: expr) => { $stream.write_all(&[$this_version; 1])?; $stream.write_all(&[$min_version_that_can_read_this; 1])?; - } + }; } /// Writes out a suffix to an object as a length-prefixed TLV stream which contains potentially @@ -730,14 +747,14 @@ macro_rules! write_tlv_fields { /// serialization logic for this object. This is compared against the /// `$min_version_that_can_read_this` added by [`write_ver_prefix`]. macro_rules! read_ver_prefix { - ($stream: expr, $this_version: expr) => { { + ($stream: expr, $this_version: expr) => {{ let ver: u8 = Readable::read($stream)?; let min_ver: u8 = Readable::read($stream)?; if min_ver > $this_version { return Err(DecodeError::UnknownVersion); } ver - } } + }}; } /// Reads a suffix added by [`write_tlv_fields`]. @@ -1002,9 +1019,15 @@ macro_rules! tlv_stream { } macro_rules! tlv_record_type { - (($type:ty, $wrapper:ident)) => { $type }; - (($type:ty, $wrapper:ident, $encoder:ty)) => { $type }; - ($type:ty) => { $type }; + (($type:ty, $wrapper:ident)) => { + $type + }; + (($type:ty, $wrapper:ident, $encoder:ty)) => { + $type + }; + ($type:ty) => { + $type + }; } macro_rules! tlv_record_ref_type { @@ -1320,7 +1343,9 @@ mod tests { use crate::io::{self, Cursor}; use crate::ln::msgs::DecodeError; - use crate::util::ser::{MaybeReadable, Readable, Writeable, HighZeroBytesDroppedBigSize, VecWriter}; + use crate::util::ser::{ + HighZeroBytesDroppedBigSize, MaybeReadable, Readable, VecWriter, Writeable, + }; use bitcoin::hex::FromHex; use bitcoin::secp256k1::PublicKey; @@ -1339,54 +1364,66 @@ mod tests { #[test] fn tlv_v_short_read() { // We only expect a u32 for type 3 (which we are given), but the L says its 8 bytes. - if let Err(DecodeError::ShortRead) = tlv_reader(&>::from_hex( - concat!("0100", "0208deadbeef1badbeef", "0308deadbeef") - ).unwrap()[..]) { - } else { panic!(); } + let buf = + >::from_hex(concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")).unwrap(); + if let Err(DecodeError::ShortRead) = tlv_reader(&buf[..]) { + } else { + panic!(); + } } #[test] fn tlv_types_out_of_order() { - if let Err(DecodeError::InvalidValue) = tlv_reader(&>::from_hex( - concat!("0100", "0304deadbeef", "0208deadbeef1badbeef") - ).unwrap()[..]) { - } else { panic!(); } + let buf = + >::from_hex(concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")).unwrap(); + if let Err(DecodeError::InvalidValue) = tlv_reader(&buf[..]) { + } else { + panic!(); + } // ...even if its some field we don't understand - if let Err(DecodeError::InvalidValue) = tlv_reader(&>::from_hex( - concat!("0208deadbeef1badbeef", "0100", "0304deadbeef") - ).unwrap()[..]) { - } else { panic!(); } + let buf = + >::from_hex(concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")).unwrap(); + if let Err(DecodeError::InvalidValue) = tlv_reader(&buf[..]) { + } else { + panic!(); + } } #[test] fn tlv_req_type_missing_or_extra() { // It's also bad if they included even fields we don't understand - if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&>::from_hex( - concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600") - ).unwrap()[..]) { - } else { panic!(); } + let buf = + >::from_hex(concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")) + .unwrap(); + if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&buf[..]) { + } else { + panic!(); + } // ... or if they're missing fields we need - if let Err(DecodeError::InvalidValue) = tlv_reader(&>::from_hex( - concat!("0100", "0208deadbeef1badbeef") - ).unwrap()[..]) { - } else { panic!(); } + let buf = >::from_hex(concat!("0100", "0208deadbeef1badbeef")).unwrap(); + if let Err(DecodeError::InvalidValue) = tlv_reader(&buf[..]) { + } else { + panic!(); + } // ... even if that field is even - if let Err(DecodeError::InvalidValue) = tlv_reader(&>::from_hex( - concat!("0304deadbeef", "0500") - ).unwrap()[..]) { - } else { panic!(); } + let buf = >::from_hex(concat!("0304deadbeef", "0500")).unwrap(); + if let Err(DecodeError::InvalidValue) = tlv_reader(&buf[..]) { + } else { + panic!(); + } } #[test] fn tlv_simple_good_cases() { - assert_eq!(tlv_reader(&>::from_hex( - concat!("0208deadbeef1badbeef", "03041bad1dea") - ).unwrap()[..]).unwrap(), - (0xdeadbeef1badbeef, 0x1bad1dea, None)); - assert_eq!(tlv_reader(&>::from_hex( - concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304") - ).unwrap()[..]).unwrap(), - (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304))); + let buf = >::from_hex(concat!("0208deadbeef1badbeef", "03041bad1dea")).unwrap(); + assert_eq!(tlv_reader(&buf[..]).unwrap(), (0xdeadbeef1badbeef, 0x1bad1dea, None)); + let buf = + >::from_hex(concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")) + .unwrap(); + assert_eq!( + tlv_reader(&buf[..]).unwrap(), + (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)) + ); } #[derive(Debug, PartialEq)] @@ -1402,39 +1439,42 @@ mod tests { let mut b = 0; let mut c: Option = None; decode_tlv_stream!(&mut s, {(2, a, upgradable_required), (3, b, upgradable_required), (4, c, upgradable_option)}); - Ok(Some(TestUpgradable { a, b, c, })) + Ok(Some(TestUpgradable { a, b, c })) } #[test] fn upgradable_tlv_simple_good_cases() { - assert_eq!(upgradable_tlv_reader(&>::from_hex( - concat!("0204deadbeef", "03041bad1dea", "0404deadbeef") - ).unwrap()[..]).unwrap(), - Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: Some(0xdeadbeef) })); - - assert_eq!(upgradable_tlv_reader(&>::from_hex( - concat!("0204deadbeef", "03041bad1dea") - ).unwrap()[..]).unwrap(), - Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: None})); + let buf = + >::from_hex(concat!("0204deadbeef", "03041bad1dea", "0404deadbeef")).unwrap(); + assert_eq!( + upgradable_tlv_reader(&buf[..]).unwrap(), + Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: Some(0xdeadbeef) }) + ); + + let buf = >::from_hex(concat!("0204deadbeef", "03041bad1dea")).unwrap(); + assert_eq!( + upgradable_tlv_reader(&buf[..]).unwrap(), + Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: None }) + ); } #[test] fn missing_required_upgradable() { - if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&>::from_hex( - concat!("0100", "0204deadbeef") - ).unwrap()[..]) { - } else { panic!(); } - if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&>::from_hex( - concat!("0100", "03041bad1dea") - ).unwrap()[..]) { - } else { panic!(); } + let buf = >::from_hex(concat!("0100", "0204deadbeef")).unwrap(); + if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&buf[..]) { + } else { + panic!(); + } + let buf = >::from_hex(concat!("0100", "03041bad1dea")).unwrap(); + if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&buf[..]) { + } else { + panic!(); + } } /// A "V1" enum with only one variant enum InnerEnumV1 { - StructVariantA { - field: u32, - }, + StructVariantA { field: u32 }, } impl_writeable_tlv_based_enum_upgradable!(InnerEnumV1, @@ -1455,12 +1495,8 @@ mod tests { /// An upgraded version of [`InnerEnumV1`] that added a second variant enum InnerEnumV2 { - StructVariantA { - field: u32, - }, - StructVariantB { - field2: u64, - } + StructVariantA { field: u32 }, + StructVariantB { field2: u64 }, } impl_writeable_tlv_based_enum_upgradable!(InnerEnumV2, @@ -1489,7 +1525,8 @@ mod tests { let serialized_bytes = OuterStructOptionalEnumV2 { inner_enum: Some(InnerEnumV2::StructVariantB { field2: 64 }), other_field: 0x1bad1dea, - }.encode(); + } + .encode(); let mut s = Cursor::new(serialized_bytes); let outer_struct: OuterStructOptionalEnumV1 = Readable::read(&mut s).unwrap(); @@ -1509,9 +1546,7 @@ mod tests { read_tlv_fields!(reader, { (0, inner_enum, upgradable_required), }); - Ok(Some(Self { - inner_enum: inner_enum.0.unwrap(), - })) + Ok(Some(Self { inner_enum: inner_enum.0.unwrap() })) } } @@ -1534,7 +1569,6 @@ mod tests { (2, other_field, required), }); - #[test] fn upgradable_enum_required() { // Test downgrading from an `OuterOuterStruct` (i.e. test downgrading an @@ -1547,7 +1581,8 @@ mod tests { let serialized_bytes = OuterOuterStruct { outer_struct: Some(OuterStructRequiredEnum { inner_enum: dummy_inner_enum }), other_field: 0x1bad1dea, - }.encode(); + } + .encode(); let mut s = Cursor::new(serialized_bytes); let outer_outer_struct: OuterOuterStruct = Readable::read(&mut s).unwrap(); @@ -1556,7 +1591,17 @@ mod tests { } // BOLT TLV test cases - fn tlv_reader_n1(s: &[u8]) -> Result<(Option>, Option, Option<(PublicKey, u64, u64)>, Option), DecodeError> { + fn tlv_reader_n1( + s: &[u8], + ) -> Result< + ( + Option>, + Option, + Option<(PublicKey, u64, u64)>, + Option, + ), + DecodeError, + > { let mut s = Cursor::new(s); let mut tlv1: Option> = None; let mut tlv2: Option = None; @@ -1570,9 +1615,13 @@ mod tests { fn bolt_tlv_bogus_stream() { macro_rules! do_test { ($stream: expr, $reason: ident) => { - if let Err(DecodeError::$reason) = tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) { - } else { panic!(); } - } + if let Err(DecodeError::$reason) = + tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) + { + } else { + panic!(); + } + }; } // TLVs from the BOLT test cases which should not decode as either n1 or n2 @@ -1595,9 +1644,13 @@ mod tests { fn bolt_tlv_bogus_n1_stream() { macro_rules! do_test { ($stream: expr, $reason: ident) => { - if let Err(DecodeError::$reason) = tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) { - } else { panic!(); } - } + if let Err(DecodeError::$reason) = + tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) + { + } else { + panic!(); + } + }; } // TLVs from the BOLT test cases which should not decode as n1 @@ -1612,7 +1665,14 @@ mod tests { do_test!(concat!("01", "08", "0001000000000000"), InvalidValue); do_test!(concat!("02", "07", "01010101010101"), ShortRead); do_test!(concat!("02", "09", "010101010101010101"), InvalidValue); - do_test!(concat!("03", "21", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb"), ShortRead); + do_test!( + concat!( + "03", + "21", + "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb" + ), + ShortRead + ); do_test!(concat!("03", "29", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb0000000000000001"), ShortRead); do_test!(concat!("03", "30", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb000000000000000100000000000001"), ShortRead); do_test!(concat!("03", "31", "043da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"), InvalidValue); @@ -1623,7 +1683,10 @@ mod tests { do_test!(concat!("00", "00"), UnknownRequiredFeature); do_test!(concat!("02", "08", "0000000000000226", "01", "01", "2a"), InvalidValue); - do_test!(concat!("02", "08", "0000000000000231", "02", "08", "0000000000000451"), InvalidValue); + do_test!( + concat!("02", "08", "0000000000000231", "02", "08", "0000000000000451"), + InvalidValue + ); do_test!(concat!("1f", "00", "0f", "01", "2a"), InvalidValue); do_test!(concat!("1f", "00", "1f", "01", "2a"), InvalidValue); @@ -1635,13 +1698,17 @@ mod tests { fn bolt_tlv_valid_n1_stream() { macro_rules! do_test { ($stream: expr, $tlv1: expr, $tlv2: expr, $tlv3: expr, $tlv4: expr) => { - if let Ok((tlv1, tlv2, tlv3, tlv4)) = tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) { + if let Ok((tlv1, tlv2, tlv3, tlv4)) = + tlv_reader_n1(&>::from_hex($stream).unwrap()[..]) + { assert_eq!(tlv1.map(|v| v.0), $tlv1); assert_eq!(tlv2, $tlv2); assert_eq!(tlv3, $tlv3); assert_eq!(tlv4, $tlv4); - } else { panic!(); } - } + } else { + panic!(); + } + }; } do_test!(concat!(""), None, None, None, None); @@ -1660,8 +1727,20 @@ mod tests { do_test!(concat!("01", "05", "0100000000"), Some(4294967296), None, None, None); do_test!(concat!("01", "06", "010000000000"), Some(1099511627776), None, None, None); do_test!(concat!("01", "07", "01000000000000"), Some(281474976710656), None, None, None); - do_test!(concat!("01", "08", "0100000000000000"), Some(72057594037927936), None, None, None); - do_test!(concat!("02", "08", "0000000000000226"), None, Some((0 << 30) | (0 << 5) | (550 << 0)), None, None); + do_test!( + concat!("01", "08", "0100000000000000"), + Some(72057594037927936), + None, + None, + None + ); + do_test!( + concat!("02", "08", "0000000000000226"), + None, + Some((0 << 30) | (0 << 5) | (550 << 0)), + None, + None + ); do_test!(concat!("03", "31", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"), None, None, Some(( PublicKey::from_slice(&>::from_hex("023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb").unwrap()[..]).unwrap(), 1, 2)), @@ -1677,7 +1756,7 @@ mod tests { assert_eq!(stream.0, >::from_hex("03010101").unwrap()); stream.0.clear(); - _encode_varint_length_prefixed_tlv!(&mut stream, {(1, Some(1u8), option)}); + _encode_varint_length_prefixed_tlv!(&mut stream, { (1, Some(1u8), option) }); assert_eq!(stream.0, >::from_hex("03010101").unwrap()); stream.0.clear(); diff --git a/lightning/src/util/test_channel_signer.rs b/lightning/src/util/test_channel_signer.rs index f3ef4dc1557..263e054b5cc 100644 --- a/lightning/src/util/test_channel_signer.rs +++ b/lightning/src/util/test_channel_signer.rs @@ -7,41 +7,45 @@ // You may not use this file except in accordance with one or both of these // licenses. +use crate::ln::chan_utils::{ + ChannelPublicKeys, ChannelTransactionParameters, ClosingTransaction, CommitmentTransaction, + HTLCOutputInCommitment, HolderCommitmentTransaction, TrustedCommitmentTransaction, +}; use crate::ln::channel::{ANCHOR_OUTPUT_VALUE_SATOSHI, MIN_CHAN_DUST_LIMIT_SATOSHIS}; -use crate::ln::chan_utils::{HTLCOutputInCommitment, ChannelPublicKeys, HolderCommitmentTransaction, CommitmentTransaction, ChannelTransactionParameters, TrustedCommitmentTransaction, ClosingTransaction}; -use crate::ln::channel_keys::{HtlcKey}; +use crate::ln::channel_keys::HtlcKey; use crate::ln::msgs; -use crate::types::payment::PaymentPreimage; -use crate::sign::{InMemorySigner, ChannelSigner}; use crate::sign::ecdsa::EcdsaChannelSigner; +use crate::sign::{ChannelSigner, InMemorySigner}; +use crate::types::payment::PaymentPreimage; #[allow(unused_imports)] use crate::prelude::*; +#[cfg(test)] +use crate::sync::MutexGuard; +use crate::sync::{Arc, Mutex}; use core::cmp; -use crate::sync::{Mutex, Arc}; -#[cfg(test)] use crate::sync::MutexGuard; -use bitcoin::transaction::Transaction; use bitcoin::hashes::Hash; use bitcoin::sighash; use bitcoin::sighash::EcdsaSighashType; +use bitcoin::transaction::Transaction; -use bitcoin::secp256k1; +use crate::io::Error; #[cfg(taproot)] -use bitcoin::secp256k1::All; -use bitcoin::secp256k1::{SecretKey, PublicKey}; -use bitcoin::secp256k1::{Secp256k1, ecdsa::Signature}; +use crate::ln::msgs::PartialSignatureWithNonce; #[cfg(taproot)] -use musig2::types::{PartialSignature, PublicNonce}; +use crate::sign::taproot::TaprootChannelSigner; use crate::sign::HTLCDescriptor; -use crate::util::ser::{Writeable, Writer}; -use crate::io::Error; use crate::types::features::ChannelTypeFeatures; +use crate::util::ser::{Writeable, Writer}; +use bitcoin::secp256k1; #[cfg(taproot)] -use crate::ln::msgs::PartialSignatureWithNonce; +use bitcoin::secp256k1::All; +use bitcoin::secp256k1::{ecdsa::Signature, Secp256k1}; +use bitcoin::secp256k1::{PublicKey, SecretKey}; #[cfg(taproot)] -use crate::sign::taproot::TaprootChannelSigner; +use musig2::types::{PartialSignature, PublicNonce}; /// Initial value for revoked commitment downward counter pub const INITIAL_REVOKED_COMMITMENT_NUMBER: u64 = 1 << 48; @@ -120,11 +124,7 @@ impl TestChannelSigner { /// Construct an TestChannelSigner pub fn new(inner: InMemorySigner) -> Self { let state = Arc::new(Mutex::new(EnforcementState::new())); - Self { - inner, - state, - disable_revocation_policy_check: false, - } + Self { inner, state, disable_revocation_policy_check: false } } /// Construct an TestChannelSigner with externally managed storage @@ -132,15 +132,16 @@ impl TestChannelSigner { /// Since there are multiple copies of this struct for each channel, some coordination is needed /// so that all copies are aware of enforcement state. A pointer to this state is provided /// here, usually by an implementation of KeysInterface. - pub fn new_with_revoked(inner: InMemorySigner, state: Arc>, disable_revocation_policy_check: bool) -> Self { - Self { - inner, - state, - disable_revocation_policy_check, - } + pub fn new_with_revoked( + inner: InMemorySigner, state: Arc>, + disable_revocation_policy_check: bool, + ) -> Self { + Self { inner, state, disable_revocation_policy_check } } - pub fn channel_type_features(&self) -> &ChannelTypeFeatures { self.inner.channel_type_features().unwrap() } + pub fn channel_type_features(&self) -> &ChannelTypeFeatures { + self.inner.channel_type_features().unwrap() + } #[cfg(test)] pub fn get_enforcement_state(&self) -> MutexGuard { @@ -164,7 +165,9 @@ impl TestChannelSigner { } impl ChannelSigner for TestChannelSigner { - fn get_per_commitment_point(&self, idx: u64, secp_ctx: &Secp256k1) -> Result { + fn get_per_commitment_point( + &self, idx: u64, secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::GetPerCommitmentPoint) { return Err(()); @@ -186,10 +189,18 @@ impl ChannelSigner for TestChannelSigner { self.inner.release_commitment_secret(idx) } - fn validate_holder_commitment(&self, holder_tx: &HolderCommitmentTransaction, _outbound_htlc_preimages: Vec) -> Result<(), ()> { + fn validate_holder_commitment( + &self, holder_tx: &HolderCommitmentTransaction, + _outbound_htlc_preimages: Vec, + ) -> Result<(), ()> { let mut state = self.state.lock().unwrap(); let idx = holder_tx.commitment_number(); - assert!(idx == state.last_holder_commitment || idx == state.last_holder_commitment - 1, "expecting to validate the current or next holder commitment - trying {}, current {}", idx, state.last_holder_commitment); + assert!( + idx == state.last_holder_commitment || idx == state.last_holder_commitment - 1, + "expecting to validate the current or next holder commitment - trying {}, current {}", + idx, + state.last_holder_commitment + ); state.last_holder_commitment = idx; Ok(()) } @@ -205,9 +216,13 @@ impl ChannelSigner for TestChannelSigner { Ok(()) } - fn pubkeys(&self) -> &ChannelPublicKeys { self.inner.pubkeys() } + fn pubkeys(&self) -> &ChannelPublicKeys { + self.inner.pubkeys() + } - fn channel_keys_id(&self) -> [u8; 32] { self.inner.channel_keys_id() } + fn channel_keys_id(&self) -> [u8; 32] { + self.inner.channel_keys_id() + } fn provide_channel_parameters(&mut self, channel_parameters: &ChannelTransactionParameters) { self.inner.provide_channel_parameters(channel_parameters) @@ -215,7 +230,10 @@ impl ChannelSigner for TestChannelSigner { } impl EcdsaChannelSigner for TestChannelSigner { - fn sign_counterparty_commitment(&self, commitment_tx: &CommitmentTransaction, inbound_htlc_preimages: Vec, outbound_htlc_preimages: Vec, secp_ctx: &Secp256k1) -> Result<(Signature, Vec), ()> { + fn sign_counterparty_commitment( + &self, commitment_tx: &CommitmentTransaction, inbound_htlc_preimages: Vec, + outbound_htlc_preimages: Vec, secp_ctx: &Secp256k1, + ) -> Result<(Signature, Vec), ()> { self.verify_counterparty_commitment_tx(commitment_tx, secp_ctx); { @@ -228,17 +246,39 @@ impl EcdsaChannelSigner for TestChannelSigner { let last_commitment_number = state.last_counterparty_commitment; // These commitment numbers are backwards counting. We expect either the same as the previously encountered, // or the next one. - assert!(last_commitment_number == actual_commitment_number || last_commitment_number - 1 == actual_commitment_number, "{} doesn't come after {}", actual_commitment_number, last_commitment_number); + assert!( + last_commitment_number == actual_commitment_number + || last_commitment_number - 1 == actual_commitment_number, + "{} doesn't come after {}", + actual_commitment_number, + last_commitment_number + ); // Ensure that the counterparty doesn't get more than two broadcastable commitments - // the last and the one we are trying to sign - assert!(actual_commitment_number >= state.last_counterparty_revoked_commitment - 2, "cannot sign a commitment if second to last wasn't revoked - signing {} revoked {}", actual_commitment_number, state.last_counterparty_revoked_commitment); - state.last_counterparty_commitment = cmp::min(last_commitment_number, actual_commitment_number) + assert!( + actual_commitment_number >= state.last_counterparty_revoked_commitment - 2, + "cannot sign a commitment if second to last wasn't revoked - signing {} revoked {}", + actual_commitment_number, + state.last_counterparty_revoked_commitment + ); + state.last_counterparty_commitment = + cmp::min(last_commitment_number, actual_commitment_number) } - Ok(self.inner.sign_counterparty_commitment(commitment_tx, inbound_htlc_preimages, outbound_htlc_preimages, secp_ctx).unwrap()) + Ok(self + .inner + .sign_counterparty_commitment( + commitment_tx, + inbound_htlc_preimages, + outbound_htlc_preimages, + secp_ctx, + ) + .unwrap()) } - fn sign_holder_commitment(&self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1) -> Result { + fn sign_holder_commitment( + &self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignHolderCommitment) { return Err(()); @@ -246,7 +286,9 @@ impl EcdsaChannelSigner for TestChannelSigner { let trusted_tx = self.verify_holder_commitment_tx(commitment_tx, secp_ctx); let state = self.state.lock().unwrap(); let commitment_number = trusted_tx.commitment_number(); - if state.last_holder_revoked_commitment - 1 != commitment_number && state.last_holder_revoked_commitment - 2 != commitment_number { + if state.last_holder_revoked_commitment - 1 != commitment_number + && state.last_holder_revoked_commitment - 2 != commitment_number + { if !self.disable_revocation_policy_check { panic!("can only sign the next two unrevoked commitment numbers, revoked={} vs requested={} for {}", state.last_holder_revoked_commitment, commitment_number, self.inner.commitment_seed[0]) @@ -255,38 +297,63 @@ impl EcdsaChannelSigner for TestChannelSigner { Ok(self.inner.sign_holder_commitment(commitment_tx, secp_ctx).unwrap()) } - #[cfg(any(test,feature = "unsafe_revoked_tx_signing"))] - fn unsafe_sign_holder_commitment(&self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1) -> Result { + #[cfg(any(test, feature = "unsafe_revoked_tx_signing"))] + fn unsafe_sign_holder_commitment( + &self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1, + ) -> Result { Ok(self.inner.unsafe_sign_holder_commitment(commitment_tx, secp_ctx).unwrap()) } - fn sign_justice_revoked_output(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, secp_ctx: &Secp256k1) -> Result { + fn sign_justice_revoked_output( + &self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, + secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignJusticeRevokedOutput) { return Err(()); } - Ok(EcdsaChannelSigner::sign_justice_revoked_output(&self.inner, justice_tx, input, amount, per_commitment_key, secp_ctx).unwrap()) - } - - fn sign_justice_revoked_htlc(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1) -> Result { + Ok(EcdsaChannelSigner::sign_justice_revoked_output( + &self.inner, + justice_tx, + input, + amount, + per_commitment_key, + secp_ctx, + ) + .unwrap()) + } + + fn sign_justice_revoked_htlc( + &self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, + htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignJusticeRevokedHtlc) { return Err(()); } - Ok(EcdsaChannelSigner::sign_justice_revoked_htlc(&self.inner, justice_tx, input, amount, per_commitment_key, htlc, secp_ctx).unwrap()) + Ok(EcdsaChannelSigner::sign_justice_revoked_htlc( + &self.inner, + justice_tx, + input, + amount, + per_commitment_key, + htlc, + secp_ctx, + ) + .unwrap()) } fn sign_holder_htlc_transaction( &self, htlc_tx: &Transaction, input: usize, htlc_descriptor: &HTLCDescriptor, - secp_ctx: &Secp256k1 + secp_ctx: &Secp256k1, ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignHolderHtlcTransaction) { return Err(()); } let state = self.state.lock().unwrap(); - if state.last_holder_revoked_commitment - 1 != htlc_descriptor.per_commitment_number && - state.last_holder_revoked_commitment - 2 != htlc_descriptor.per_commitment_number + if state.last_holder_revoked_commitment - 1 != htlc_descriptor.per_commitment_number + && state.last_holder_revoked_commitment - 2 != htlc_descriptor.per_commitment_number { if !self.disable_revocation_policy_check { panic!("can only sign the next two unrevoked commitment numbers, revoked={} vs requested={} for {}", @@ -302,34 +369,60 @@ impl EcdsaChannelSigner for TestChannelSigner { } else { EcdsaSighashType::All }; - let sighash = &sighash::SighashCache::new(&*htlc_tx).p2wsh_signature_hash( - input, &witness_script, htlc_descriptor.htlc.to_bitcoin_amount(), sighash_type - ).unwrap(); + let sighash = &sighash::SighashCache::new(&*htlc_tx) + .p2wsh_signature_hash( + input, + &witness_script, + htlc_descriptor.htlc.to_bitcoin_amount(), + sighash_type, + ) + .unwrap(); let countersignatory_htlc_key = HtlcKey::from_basepoint( - &secp_ctx, &self.inner.counterparty_pubkeys().unwrap().htlc_basepoint, &htlc_descriptor.per_commitment_point, + &secp_ctx, + &self.inner.counterparty_pubkeys().unwrap().htlc_basepoint, + &htlc_descriptor.per_commitment_point, ); - secp_ctx.verify_ecdsa( - &hash_to_message!(sighash.as_byte_array()), &htlc_descriptor.counterparty_sig, &countersignatory_htlc_key.to_public_key() - ).unwrap(); + secp_ctx + .verify_ecdsa( + &hash_to_message!(sighash.as_byte_array()), + &htlc_descriptor.counterparty_sig, + &countersignatory_htlc_key.to_public_key(), + ) + .unwrap(); } - Ok(EcdsaChannelSigner::sign_holder_htlc_transaction(&self.inner, htlc_tx, input, htlc_descriptor, secp_ctx).unwrap()) + self.inner.sign_holder_htlc_transaction(htlc_tx, input, htlc_descriptor, secp_ctx) } - fn sign_counterparty_htlc_transaction(&self, htlc_tx: &Transaction, input: usize, amount: u64, per_commitment_point: &PublicKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1) -> Result { + fn sign_counterparty_htlc_transaction( + &self, htlc_tx: &Transaction, input: usize, amount: u64, per_commitment_point: &PublicKey, + htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignCounterpartyHtlcTransaction) { return Err(()); } - Ok(EcdsaChannelSigner::sign_counterparty_htlc_transaction(&self.inner, htlc_tx, input, amount, per_commitment_point, htlc, secp_ctx).unwrap()) - } - - fn sign_closing_transaction(&self, closing_tx: &ClosingTransaction, secp_ctx: &Secp256k1) -> Result { + Ok(EcdsaChannelSigner::sign_counterparty_htlc_transaction( + &self.inner, + htlc_tx, + input, + amount, + per_commitment_point, + htlc, + secp_ctx, + ) + .unwrap()) + } + + fn sign_closing_transaction( + &self, closing_tx: &ClosingTransaction, secp_ctx: &Secp256k1, + ) -> Result { #[cfg(test)] if !self.is_signer_available(SignerOp::SignClosingTransaction) { return Err(()); } - closing_tx.verify(self.inner.funding_outpoint().unwrap().into_bitcoin_outpoint()) + closing_tx + .verify(self.inner.funding_outpoint().unwrap().into_bitcoin_outpoint()) .expect("derived different closing transaction"); Ok(self.inner.sign_closing_transaction(closing_tx, secp_ctx).unwrap()) } @@ -340,7 +433,10 @@ impl EcdsaChannelSigner for TestChannelSigner { debug_assert!(MIN_CHAN_DUST_LIMIT_SATOSHIS > ANCHOR_OUTPUT_VALUE_SATOSHI); // As long as our minimum dust limit is enforced and is greater than our anchor output // value, an anchor output can only have an index within [0, 1]. - assert!(anchor_tx.input[input].previous_output.vout == 0 || anchor_tx.input[input].previous_output.vout == 1); + assert!( + anchor_tx.input[input].previous_output.vout == 0 + || anchor_tx.input[input].previous_output.vout == 1 + ); #[cfg(test)] if !self.is_signer_available(SignerOp::SignHolderAnchorInput) { return Err(()); @@ -349,7 +445,7 @@ impl EcdsaChannelSigner for TestChannelSigner { } fn sign_channel_announcement_with_funding_key( - &self, msg: &msgs::UnsignedChannelAnnouncement, secp_ctx: &Secp256k1 + &self, msg: &msgs::UnsignedChannelAnnouncement, secp_ctx: &Secp256k1, ) -> Result { self.inner.sign_channel_announcement_with_funding_key(msg, secp_ctx) } @@ -365,39 +461,64 @@ impl EcdsaChannelSigner for TestChannelSigner { #[cfg(taproot)] #[allow(unused)] impl TaprootChannelSigner for TestChannelSigner { - fn generate_local_nonce_pair(&self, commitment_number: u64, secp_ctx: &Secp256k1) -> PublicNonce { + fn generate_local_nonce_pair( + &self, commitment_number: u64, secp_ctx: &Secp256k1, + ) -> PublicNonce { todo!() } - fn partially_sign_counterparty_commitment(&self, counterparty_nonce: PublicNonce, commitment_tx: &CommitmentTransaction, inbound_htlc_preimages: Vec, outbound_htlc_preimages: Vec, secp_ctx: &Secp256k1) -> Result<(PartialSignatureWithNonce, Vec), ()> { + fn partially_sign_counterparty_commitment( + &self, counterparty_nonce: PublicNonce, commitment_tx: &CommitmentTransaction, + inbound_htlc_preimages: Vec, + outbound_htlc_preimages: Vec, secp_ctx: &Secp256k1, + ) -> Result<(PartialSignatureWithNonce, Vec), ()> { todo!() } - fn finalize_holder_commitment(&self, commitment_tx: &HolderCommitmentTransaction, counterparty_partial_signature: PartialSignatureWithNonce, secp_ctx: &Secp256k1) -> Result { + fn finalize_holder_commitment( + &self, commitment_tx: &HolderCommitmentTransaction, + counterparty_partial_signature: PartialSignatureWithNonce, secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn sign_justice_revoked_output(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, secp_ctx: &Secp256k1) -> Result { + fn sign_justice_revoked_output( + &self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, + secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn sign_justice_revoked_htlc(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1) -> Result { + fn sign_justice_revoked_htlc( + &self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, + htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn sign_holder_htlc_transaction(&self, htlc_tx: &Transaction, input: usize, htlc_descriptor: &HTLCDescriptor, secp_ctx: &Secp256k1) -> Result { + fn sign_holder_htlc_transaction( + &self, htlc_tx: &Transaction, input: usize, htlc_descriptor: &HTLCDescriptor, + secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn sign_counterparty_htlc_transaction(&self, htlc_tx: &Transaction, input: usize, amount: u64, per_commitment_point: &PublicKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1) -> Result { + fn sign_counterparty_htlc_transaction( + &self, htlc_tx: &Transaction, input: usize, amount: u64, per_commitment_point: &PublicKey, + htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn partially_sign_closing_transaction(&self, closing_tx: &ClosingTransaction, secp_ctx: &Secp256k1) -> Result { + fn partially_sign_closing_transaction( + &self, closing_tx: &ClosingTransaction, secp_ctx: &Secp256k1, + ) -> Result { todo!() } - fn sign_holder_anchor_input(&self, anchor_tx: &Transaction, input: usize, secp_ctx: &Secp256k1) -> Result { + fn sign_holder_anchor_input( + &self, anchor_tx: &Transaction, input: usize, secp_ctx: &Secp256k1, + ) -> Result { todo!() } } @@ -414,18 +535,30 @@ impl Writeable for TestChannelSigner { } impl TestChannelSigner { - fn verify_counterparty_commitment_tx<'a, T: secp256k1::Signing + secp256k1::Verification>(&self, commitment_tx: &'a CommitmentTransaction, secp_ctx: &Secp256k1) -> TrustedCommitmentTransaction<'a> { - commitment_tx.verify( - &self.inner.get_channel_parameters().unwrap().as_counterparty_broadcastable(), - self.inner.counterparty_pubkeys().unwrap(), self.inner.pubkeys(), secp_ctx - ).expect("derived different per-tx keys or built transaction") - } - - fn verify_holder_commitment_tx<'a, T: secp256k1::Signing + secp256k1::Verification>(&self, commitment_tx: &'a CommitmentTransaction, secp_ctx: &Secp256k1) -> TrustedCommitmentTransaction<'a> { - commitment_tx.verify( - &self.inner.get_channel_parameters().unwrap().as_holder_broadcastable(), - self.inner.pubkeys(), self.inner.counterparty_pubkeys().unwrap(), secp_ctx - ).expect("derived different per-tx keys or built transaction") + fn verify_counterparty_commitment_tx<'a, T: secp256k1::Signing + secp256k1::Verification>( + &self, commitment_tx: &'a CommitmentTransaction, secp_ctx: &Secp256k1, + ) -> TrustedCommitmentTransaction<'a> { + commitment_tx + .verify( + &self.inner.get_channel_parameters().unwrap().as_counterparty_broadcastable(), + self.inner.counterparty_pubkeys().unwrap(), + self.inner.pubkeys(), + secp_ctx, + ) + .expect("derived different per-tx keys or built transaction") + } + + fn verify_holder_commitment_tx<'a, T: secp256k1::Signing + secp256k1::Verification>( + &self, commitment_tx: &'a CommitmentTransaction, secp_ctx: &Secp256k1, + ) -> TrustedCommitmentTransaction<'a> { + commitment_tx + .verify( + &self.inner.get_channel_parameters().unwrap().as_holder_broadcastable(), + self.inner.pubkeys(), + self.inner.counterparty_pubkeys().unwrap(), + secp_ctx, + ) + .expect("derived different per-tx keys or built transaction") } } diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 07b2b19b0d6..27bce67e862 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -11,7 +11,6 @@ use crate::blinded_path::message::MessageContext; use crate::blinded_path::message::{BlindedMessagePath, MessageForwardNode}; use crate::blinded_path::payment::{BlindedPaymentPath, ReceiveTlvs}; use crate::chain; -use crate::chain::WatchedOutput; use crate::chain::chaininterface; use crate::chain::chaininterface::ConfirmationTarget; #[cfg(test)] @@ -20,60 +19,70 @@ use crate::chain::chainmonitor; use crate::chain::channelmonitor; use crate::chain::channelmonitor::MonitorEvent; use crate::chain::transaction::OutPoint; -use crate::routing::router::{CandidateRouteHop, FirstHopCandidate, PublicHopCandidate, PrivateHopCandidate}; -use crate::sign; +use crate::chain::WatchedOutput; use crate::events; -use crate::events::bump_transaction::{WalletSource, Utxo}; -use crate::ln::types::ChannelId; -use crate::ln::channel_state::ChannelDetails; -use crate::ln::channelmanager; +use crate::events::bump_transaction::{Utxo, WalletSource}; #[cfg(test)] use crate::ln::chan_utils::CommitmentTransaction; -use crate::types::features::{ChannelFeatures, InitFeatures, NodeFeatures}; +use crate::ln::channel_state::ChannelDetails; +use crate::ln::channelmanager; use crate::ln::inbound_payment::ExpandedKey; -use crate::ln::{msgs, wire}; use crate::ln::msgs::LightningError; use crate::ln::script::ShutdownScript; +use crate::ln::types::ChannelId; +use crate::ln::{msgs, wire}; use crate::offers::invoice::UnsignedBolt12Invoice; -use crate::onion_message::messenger::{DefaultMessageRouter, Destination, MessageRouter, OnionMessagePath}; +use crate::onion_message::messenger::{ + DefaultMessageRouter, Destination, MessageRouter, OnionMessagePath, +}; use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId, RoutingFees}; +use crate::routing::router::{ + CandidateRouteHop, FirstHopCandidate, PrivateHopCandidate, PublicHopCandidate, +}; +use crate::routing::router::{ + DefaultRouter, InFlightHtlcs, Path, Route, RouteHintHop, RouteParameters, Router, + ScorerAccountingForInFlightHtlcs, +}; +use crate::routing::scoring::{ChannelUsage, ScoreLookUp, ScoreUpdate}; use crate::routing::utxo::{UtxoLookup, UtxoLookupError, UtxoResult}; -use crate::routing::router::{DefaultRouter, InFlightHtlcs, Path, Route, RouteParameters, RouteHintHop, Router, ScorerAccountingForInFlightHtlcs}; -use crate::routing::scoring::{ChannelUsage, ScoreUpdate, ScoreLookUp}; +use crate::sign; use crate::sync::RwLock; +use crate::types::features::{ChannelFeatures, InitFeatures, NodeFeatures}; use crate::util::config::UserConfig; -use crate::util::test_channel_signer::{TestChannelSigner, EnforcementState}; use crate::util::logger::{Logger, Record}; -use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable}; use crate::util::persist::KVStore; +use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer}; +use crate::util::test_channel_signer::{EnforcementState, TestChannelSigner}; use bitcoin::amount::Amount; -use bitcoin::constants::ChainHash; -use bitcoin::constants::genesis_block; -use bitcoin::transaction::{Transaction, TxOut}; -use bitcoin::script::{Builder, Script, ScriptBuf}; -use bitcoin::opcodes; use bitcoin::block::Block; -use bitcoin::network::Network; +use bitcoin::constants::genesis_block; +use bitcoin::constants::ChainHash; use bitcoin::hash_types::{BlockHash, Txid}; use bitcoin::hashes::Hash; -use bitcoin::sighash::{SighashCache, EcdsaSighashType}; +use bitcoin::network::Network; +use bitcoin::opcodes; +use bitcoin::script::{Builder, Script, ScriptBuf}; +use bitcoin::sighash::{EcdsaSighashType, SighashCache}; +use bitcoin::transaction::{Transaction, TxOut}; -use bitcoin::secp256k1::{PublicKey, Scalar, Secp256k1, SecretKey, self}; use bitcoin::secp256k1::ecdh::SharedSecret; use bitcoin::secp256k1::ecdsa::{RecoverableSignature, Signature}; use bitcoin::secp256k1::schnorr; +use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey}; use lightning_invoice::RawBolt11Invoice; use crate::io; use crate::prelude::*; +use crate::sign::{ + EntropySource, InMemorySigner, NodeSigner, RandomBytes, Recipient, SignerProvider, +}; +use crate::sync::{Arc, Mutex}; use core::cell::RefCell; -use core::time::Duration; -use crate::sync::{Mutex, Arc}; -use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use core::mem; -use crate::sign::{InMemorySigner, RandomBytes, Recipient, EntropySource, NodeSigner, SignerProvider}; +use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use core::time::Duration; use bitcoin::psbt::Psbt; use bitcoin::Sequence; @@ -103,15 +112,19 @@ pub struct TestFeeEstimator { } impl TestFeeEstimator { pub fn new(sat_per_kw: u32) -> Self { - Self { - sat_per_kw: Mutex::new(sat_per_kw), - target_override: Mutex::new(new_hash_map()), - } + let sat_per_kw = Mutex::new(sat_per_kw); + let target_override = Mutex::new(new_hash_map()); + Self { sat_per_kw, target_override } } } impl chaininterface::FeeEstimator for TestFeeEstimator { fn get_est_sat_per_1000_weight(&self, conf_target: ConfirmationTarget) -> u32 { - *self.target_override.lock().unwrap().get(&conf_target).unwrap_or(&*self.sat_per_kw.lock().unwrap()) + *self + .target_override + .lock() + .unwrap() + .get(&conf_target) + .unwrap_or(&*self.sat_per_kw.lock().unwrap()) } } @@ -136,11 +149,19 @@ impl<'a> TestRouter<'a> { scorer: &'a RwLock, ) -> Self { let entropy_source = Arc::new(RandomBytes::new([42; 32])); + let next_routes = Mutex::new(VecDeque::new()); + let next_blinded_payment_paths = Mutex::new(Vec::new()); Self { - router: DefaultRouter::new(network_graph.clone(), logger, entropy_source, scorer, Default::default()), + router: DefaultRouter::new( + network_graph.clone(), + logger, + entropy_source, + scorer, + Default::default(), + ), network_graph, - next_routes: Mutex::new(VecDeque::new()), - next_blinded_payment_paths: Mutex::new(Vec::new()), + next_routes, + next_blinded_payment_paths, scorer, } } @@ -164,7 +185,7 @@ impl<'a> TestRouter<'a> { impl<'a> Router for TestRouter<'a> { fn find_route( &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, - inflight_htlcs: InFlightHtlcs + inflight_htlcs: InFlightHtlcs, ) -> Result { let route_res; let next_route_opt = self.next_routes.lock().unwrap().pop_front(); @@ -188,22 +209,31 @@ impl<'a> Router for TestRouter<'a> { if idx == path.hops.len() - 1 { if let Some(first_hops) = first_hops { - if let Some(idx) = first_hops.iter().position(|h| h.get_outbound_payment_scid() == Some(hop.short_channel_id)) { + if let Some(idx) = first_hops.iter().position(|h| { + h.get_outbound_payment_scid() == Some(hop.short_channel_id) + }) { let node_id = NodeId::from_pubkey(payer); - let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { - details: first_hops[idx], - payer_node_id: &node_id, - payer_node_counter: u32::max_value(), - target_node_counter: u32::max_value(), - }); - scorer.channel_penalty_msat(&candidate, usage, &Default::default()); + let candidate = + CandidateRouteHop::FirstHop(FirstHopCandidate { + details: first_hops[idx], + payer_node_id: &node_id, + payer_node_counter: u32::max_value(), + target_node_counter: u32::max_value(), + }); + scorer.channel_penalty_msat( + &candidate, + usage, + &Default::default(), + ); continue; } } } let network_graph = self.network_graph.read_only(); if let Some(channel) = network_graph.channel(hop.short_channel_id) { - let (directed, _) = channel.as_directed_to(&NodeId::from_pubkey(&hop.pubkey)).unwrap(); + let (directed, _) = channel + .as_directed_to(&NodeId::from_pubkey(&hop.pubkey)) + .unwrap(); let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate { info: directed, short_channel_id: hop.short_channel_id, @@ -219,12 +249,13 @@ impl<'a> Router for TestRouter<'a> { htlc_minimum_msat: None, htlc_maximum_msat: None, }; - let candidate = CandidateRouteHop::PrivateHop(PrivateHopCandidate { - hint: &route_hint, - target_node_id: &target_node_id, - source_node_counter: u32::max_value(), - target_node_counter: u32::max_value(), - }); + let candidate = + CandidateRouteHop::PrivateHop(PrivateHopCandidate { + hint: &route_hint, + target_node_id: &target_node_id, + source_node_counter: u32::max_value(), + target_node_counter: u32::max_value(), + }); scorer.channel_penalty_msat(&candidate, usage, &Default::default()); } prev_hop_node = &hop.pubkey; @@ -248,16 +279,18 @@ impl<'a> Router for TestRouter<'a> { route_res } - fn create_blinded_payment_paths< - T: secp256k1::Signing + secp256k1::Verification - >( + fn create_blinded_payment_paths( &self, recipient: PublicKey, first_hops: Vec, tlvs: ReceiveTlvs, amount_msats: u64, secp_ctx: &Secp256k1, ) -> Result, ()> { let mut expected_paths = self.next_blinded_payment_paths.lock().unwrap(); if expected_paths.is_empty() { self.router.create_blinded_payment_paths( - recipient, first_hops, tlvs, amount_msats, secp_ctx + recipient, + first_hops, + tlvs, + amount_msats, + secp_ctx, ) } else { Ok(core::mem::take(&mut *expected_paths)) @@ -275,32 +308,38 @@ impl<'a> Drop for TestRouter<'a> { } pub struct TestMessageRouter<'a> { - inner: DefaultMessageRouter>, &'a TestLogger, &'a TestKeysInterface>, + inner: DefaultMessageRouter< + Arc>, + &'a TestLogger, + &'a TestKeysInterface, + >, } impl<'a> TestMessageRouter<'a> { - pub fn new(network_graph: Arc>, entropy_source: &'a TestKeysInterface) -> Self { + pub fn new( + network_graph: Arc>, entropy_source: &'a TestKeysInterface, + ) -> Self { Self { inner: DefaultMessageRouter::new(network_graph, entropy_source) } } } impl<'a> MessageRouter for TestMessageRouter<'a> { fn find_path( - &self, sender: PublicKey, peers: Vec, destination: Destination + &self, sender: PublicKey, peers: Vec, destination: Destination, ) -> Result { self.inner.find_path(sender, peers, destination) } fn create_blinded_paths( - &self, recipient: PublicKey, context: MessageContext, - peers: Vec, secp_ctx: &Secp256k1, + &self, recipient: PublicKey, context: MessageContext, peers: Vec, + secp_ctx: &Secp256k1, ) -> Result, ()> { self.inner.create_blinded_paths(recipient, context, peers, secp_ctx) } fn create_compact_blinded_paths( - &self, recipient: PublicKey, context: MessageContext, - peers: Vec, secp_ctx: &Secp256k1, + &self, recipient: PublicKey, context: MessageContext, peers: Vec, + secp_ctx: &Secp256k1, ) -> Result, ()> { self.inner.create_compact_blinded_paths(recipient, context, peers, secp_ctx) } @@ -309,37 +348,55 @@ impl<'a> MessageRouter for TestMessageRouter<'a> { pub struct OnlyReadsKeysInterface {} impl EntropySource for OnlyReadsKeysInterface { - fn get_secure_random_bytes(&self) -> [u8; 32] { [0; 32] }} + fn get_secure_random_bytes(&self) -> [u8; 32] { + [0; 32] + } +} impl SignerProvider for OnlyReadsKeysInterface { type EcdsaSigner = TestChannelSigner; #[cfg(taproot)] type TaprootSigner = TestChannelSigner; - fn generate_channel_keys_id(&self, _inbound: bool, _channel_value_satoshis: u64, _user_channel_id: u128) -> [u8; 32] { unreachable!(); } + fn generate_channel_keys_id( + &self, _inbound: bool, _channel_value_satoshis: u64, _user_channel_id: u128, + ) -> [u8; 32] { + unreachable!(); + } - fn derive_channel_signer(&self, _channel_value_satoshis: u64, _channel_keys_id: [u8; 32]) -> Self::EcdsaSigner { unreachable!(); } + fn derive_channel_signer( + &self, _channel_value_satoshis: u64, _channel_keys_id: [u8; 32], + ) -> Self::EcdsaSigner { + unreachable!(); + } fn read_chan_signer(&self, mut reader: &[u8]) -> Result { let inner: InMemorySigner = ReadableArgs::read(&mut reader, self)?; let state = Arc::new(Mutex::new(EnforcementState::new())); - Ok(TestChannelSigner::new_with_revoked( - inner, - state, - false - )) + Ok(TestChannelSigner::new_with_revoked(inner, state, false)) } - fn get_destination_script(&self, _channel_keys_id: [u8; 32]) -> Result { Err(()) } - fn get_shutdown_scriptpubkey(&self) -> Result { Err(()) } + fn get_destination_script(&self, _channel_keys_id: [u8; 32]) -> Result { + Err(()) + } + fn get_shutdown_scriptpubkey(&self) -> Result { + Err(()) + } } pub struct TestChainMonitor<'a> { pub added_monitors: Mutex)>>, pub monitor_updates: Mutex>>, pub latest_monitor_update_id: Mutex>, - pub chain_monitor: chainmonitor::ChainMonitor>, + pub chain_monitor: chainmonitor::ChainMonitor< + TestChannelSigner, + &'a TestChainSource, + &'a dyn chaininterface::BroadcasterInterface, + &'a TestFeeEstimator, + &'a TestLogger, + &'a dyn chainmonitor::Persist, + >, pub keys_manager: &'a TestKeysInterface, /// If this is set to Some(), the next update_channel call (not watch_channel) must be a /// ChannelForceClosed event for the given channel_id with should_broadcast set to the given @@ -350,58 +407,101 @@ pub struct TestChainMonitor<'a> { pub expect_monitor_round_trip_fail: Mutex>, } impl<'a> TestChainMonitor<'a> { - pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a dyn chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a dyn chainmonitor::Persist, keys_manager: &'a TestKeysInterface) -> Self { + pub fn new( + chain_source: Option<&'a TestChainSource>, + broadcaster: &'a dyn chaininterface::BroadcasterInterface, logger: &'a TestLogger, + fee_estimator: &'a TestFeeEstimator, + persister: &'a dyn chainmonitor::Persist, + keys_manager: &'a TestKeysInterface, + ) -> Self { + let added_monitors = Mutex::new(Vec::new()); + let monitor_updates = Mutex::new(new_hash_map()); + let latest_monitor_update_id = Mutex::new(new_hash_map()); + let expect_channel_force_closed = Mutex::new(None); + let expect_monitor_round_trip_fail = Mutex::new(None); Self { - added_monitors: Mutex::new(Vec::new()), - monitor_updates: Mutex::new(new_hash_map()), - latest_monitor_update_id: Mutex::new(new_hash_map()), - chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister), + added_monitors, + monitor_updates, + latest_monitor_update_id, + chain_monitor: chainmonitor::ChainMonitor::new( + chain_source, + broadcaster, + logger, + fee_estimator, + persister, + ), keys_manager, - expect_channel_force_closed: Mutex::new(None), - expect_monitor_round_trip_fail: Mutex::new(None), + expect_channel_force_closed, + expect_monitor_round_trip_fail, } } pub fn complete_sole_pending_chan_update(&self, channel_id: &ChannelId) { - let (outpoint, _, latest_update) = self.latest_monitor_update_id.lock().unwrap().get(channel_id).unwrap().clone(); + let (outpoint, _, latest_update) = + self.latest_monitor_update_id.lock().unwrap().get(channel_id).unwrap().clone(); self.chain_monitor.channel_monitor_updated(outpoint, latest_update).unwrap(); } } impl<'a> chain::Watch for TestChainMonitor<'a> { - fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result { + fn watch_channel( + &self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor, + ) -> Result { // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... let mut w = TestVecWriter(Vec::new()); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager)).unwrap().1; + &mut io::Cursor::new(&w.0), + (self.keys_manager, self.keys_manager), + ) + .unwrap() + .1; assert!(new_monitor == monitor); - self.latest_monitor_update_id.lock().unwrap().insert(monitor.channel_id(), - (funding_txo, monitor.get_latest_update_id(), monitor.get_latest_update_id())); + self.latest_monitor_update_id.lock().unwrap().insert( + monitor.channel_id(), + (funding_txo, monitor.get_latest_update_id(), monitor.get_latest_update_id()), + ); self.added_monitors.lock().unwrap().push((funding_txo, monitor)); self.chain_monitor.watch_channel(funding_txo, new_monitor) } - fn update_channel(&self, funding_txo: OutPoint, update: &channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus { + fn update_channel( + &self, funding_txo: OutPoint, update: &channelmonitor::ChannelMonitorUpdate, + ) -> chain::ChannelMonitorUpdateStatus { // Every monitor update should survive roundtrip let mut w = TestVecWriter(Vec::new()); update.write(&mut w).unwrap(); - assert!(channelmonitor::ChannelMonitorUpdate::read( - &mut io::Cursor::new(&w.0)).unwrap() == *update); - let channel_id = update.channel_id.unwrap_or(ChannelId::v1_from_funding_outpoint(funding_txo)); + assert!( + channelmonitor::ChannelMonitorUpdate::read(&mut io::Cursor::new(&w.0)).unwrap() + == *update + ); + let channel_id = + update.channel_id.unwrap_or(ChannelId::v1_from_funding_outpoint(funding_txo)); - self.monitor_updates.lock().unwrap().entry(channel_id).or_insert(Vec::new()).push(update.clone()); + self.monitor_updates + .lock() + .unwrap() + .entry(channel_id) + .or_insert(Vec::new()) + .push(update.clone()); if let Some(exp) = self.expect_channel_force_closed.lock().unwrap().take() { assert_eq!(channel_id, exp.0); assert_eq!(update.updates.len(), 1); - if let channelmonitor::ChannelMonitorUpdateStep::ChannelForceClosed { should_broadcast } = update.updates[0] { + if let channelmonitor::ChannelMonitorUpdateStep::ChannelForceClosed { + should_broadcast, + } = update.updates[0] + { assert_eq!(should_broadcast, exp.1); - } else { panic!(); } + } else { + panic!(); + } } - self.latest_monitor_update_id.lock().unwrap().insert(channel_id, - (funding_txo, update.update_id, update.update_id)); + self.latest_monitor_update_id + .lock() + .unwrap() + .insert(channel_id, (funding_txo, update.update_id, update.update_id)); let update_res = self.chain_monitor.update_channel(funding_txo, update); // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... @@ -409,7 +509,11 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { w.0.clear(); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager)).unwrap().1; + &mut io::Cursor::new(&w.0), + (self.keys_manager, self.keys_manager), + ) + .unwrap() + .1; if let Some(chan_id) = self.expect_monitor_round_trip_fail.lock().unwrap().take() { assert_eq!(chan_id, channel_id); assert!(new_monitor != *monitor); @@ -420,7 +524,9 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { update_res } - fn release_pending_monitor_events(&self) -> Vec<(OutPoint, ChannelId, Vec, Option)> { + fn release_pending_monitor_events( + &self, + ) -> Vec<(OutPoint, ChannelId, Vec, Option)> { return self.chain_monitor.release_pending_monitor_events(); } } @@ -450,51 +556,79 @@ pub(crate) struct WatchtowerPersister { impl WatchtowerPersister { #[cfg(test)] pub(crate) fn new(destination_script: ScriptBuf) -> Self { + let unsigned_justice_tx_data = Mutex::new(new_hash_map()); + let watchtower_state = Mutex::new(new_hash_map()); WatchtowerPersister { persister: TestPersister::new(), - unsigned_justice_tx_data: Mutex::new(new_hash_map()), - watchtower_state: Mutex::new(new_hash_map()), + unsigned_justice_tx_data, + watchtower_state, destination_script, } } #[cfg(test)] - pub(crate) fn justice_tx(&self, funding_txo: OutPoint, commitment_txid: &Txid) - -> Option { - self.watchtower_state.lock().unwrap().get(&funding_txo).unwrap().get(commitment_txid).cloned() - } - - fn form_justice_data_from_commitment(&self, counterparty_commitment_tx: &CommitmentTransaction) - -> Option { + pub(crate) fn justice_tx( + &self, funding_txo: OutPoint, commitment_txid: &Txid, + ) -> Option { + self.watchtower_state + .lock() + .unwrap() + .get(&funding_txo) + .unwrap() + .get(commitment_txid) + .cloned() + } + + fn form_justice_data_from_commitment( + &self, counterparty_commitment_tx: &CommitmentTransaction, + ) -> Option { let trusted_tx = counterparty_commitment_tx.trust(); let output_idx = trusted_tx.revokeable_output_index()?; let built_tx = trusted_tx.built_transaction(); let value = built_tx.transaction.output[output_idx as usize].value; - let justice_tx = trusted_tx.build_to_local_justice_tx( - FEERATE_FLOOR_SATS_PER_KW as u64, self.destination_script.clone()).ok()?; + let justice_tx = trusted_tx + .build_to_local_justice_tx( + FEERATE_FLOOR_SATS_PER_KW as u64, + self.destination_script.clone(), + ) + .ok()?; let commitment_number = counterparty_commitment_tx.commitment_number(); Some(JusticeTxData { justice_tx, value, commitment_number }) } } #[cfg(test)] -impl chainmonitor::Persist for WatchtowerPersister { - fn persist_new_channel(&self, funding_txo: OutPoint, - data: &channelmonitor::ChannelMonitor +impl chainmonitor::Persist + for WatchtowerPersister +{ + fn persist_new_channel( + &self, funding_txo: OutPoint, data: &channelmonitor::ChannelMonitor, ) -> chain::ChannelMonitorUpdateStatus { let res = self.persister.persist_new_channel(funding_txo, data); - assert!(self.unsigned_justice_tx_data.lock().unwrap() - .insert(funding_txo, VecDeque::new()).is_none()); - assert!(self.watchtower_state.lock().unwrap() - .insert(funding_txo, new_hash_map()).is_none()); - - let initial_counterparty_commitment_tx = data.initial_counterparty_commitment_tx() - .expect("First and only call expects Some"); - if let Some(justice_data) - = self.form_justice_data_from_commitment(&initial_counterparty_commitment_tx) { - self.unsigned_justice_tx_data.lock().unwrap() - .get_mut(&funding_txo).unwrap() + assert!(self + .unsigned_justice_tx_data + .lock() + .unwrap() + .insert(funding_txo, VecDeque::new()) + .is_none()); + assert!(self + .watchtower_state + .lock() + .unwrap() + .insert(funding_txo, new_hash_map()) + .is_none()); + + let initial_counterparty_commitment_tx = + data.initial_counterparty_commitment_tx().expect("First and only call expects Some"); + if let Some(justice_data) = + self.form_justice_data_from_commitment(&initial_counterparty_commitment_tx) + { + self.unsigned_justice_tx_data + .lock() + .unwrap() + .get_mut(&funding_txo) + .unwrap() .push_back(justice_data); } res @@ -502,25 +636,37 @@ impl chainmonitor::Persist for fn update_persisted_channel( &self, funding_txo: OutPoint, update: Option<&channelmonitor::ChannelMonitorUpdate>, - data: &channelmonitor::ChannelMonitor + data: &channelmonitor::ChannelMonitor, ) -> chain::ChannelMonitorUpdateStatus { let res = self.persister.update_persisted_channel(funding_txo, update, data); if let Some(update) = update { let commitment_txs = data.counterparty_commitment_txs_from_update(update); - let justice_datas = commitment_txs.into_iter() + let justice_datas = commitment_txs + .into_iter() .filter_map(|commitment_tx| self.form_justice_data_from_commitment(&commitment_tx)); let mut channels_justice_txs = self.unsigned_justice_tx_data.lock().unwrap(); let channel_state = channels_justice_txs.get_mut(&funding_txo).unwrap(); channel_state.extend(justice_datas); - while let Some(JusticeTxData { justice_tx, value, commitment_number }) = channel_state.front() { + while let Some(JusticeTxData { justice_tx, value, commitment_number }) = + channel_state.front() + { let input_idx = 0; let commitment_txid = justice_tx.input[input_idx].previous_output.txid; - match data.sign_to_local_justice_tx(justice_tx.clone(), input_idx, value.to_sat(), *commitment_number) { + match data.sign_to_local_justice_tx( + justice_tx.clone(), + input_idx, + value.to_sat(), + *commitment_number, + ) { Ok(signed_justice_tx) => { - let dup = self.watchtower_state.lock().unwrap() - .get_mut(&funding_txo).unwrap() + let dup = self + .watchtower_state + .lock() + .unwrap() + .get_mut(&funding_txo) + .unwrap() .insert(commitment_txid, signed_justice_tx); assert!(dup.is_none()); channel_state.pop_front(); @@ -533,7 +679,10 @@ impl chainmonitor::Persist for } fn archive_persisted_channel(&self, funding_txo: OutPoint) { - >::archive_persisted_channel(&self.persister, funding_txo); + >::archive_persisted_channel( + &self.persister, + funding_txo, + ); } } @@ -548,15 +697,14 @@ pub struct TestPersister { pub offchain_monitor_updates: Mutex>>, /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the /// monitor's funding outpoint here. - pub chain_sync_monitor_persistences: Mutex> + pub chain_sync_monitor_persistences: Mutex>, } impl TestPersister { pub fn new() -> Self { - Self { - update_rets: Mutex::new(VecDeque::new()), - offchain_monitor_updates: Mutex::new(new_hash_map()), - chain_sync_monitor_persistences: Mutex::new(VecDeque::new()) - } + let update_rets = Mutex::new(VecDeque::new()); + let offchain_monitor_updates = Mutex::new(new_hash_map()); + let chain_sync_monitor_persistences = Mutex::new(VecDeque::new()); + Self { update_rets, offchain_monitor_updates, chain_sync_monitor_persistences } } /// Queue an update status to return. @@ -565,21 +713,31 @@ impl TestPersister { } } impl chainmonitor::Persist for TestPersister { - fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor) -> chain::ChannelMonitorUpdateStatus { + fn persist_new_channel( + &self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, + ) -> chain::ChannelMonitorUpdateStatus { if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { - return update_ret + return update_ret; } chain::ChannelMonitorUpdateStatus::Completed } - fn update_persisted_channel(&self, funding_txo: OutPoint, update: Option<&channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor) -> chain::ChannelMonitorUpdateStatus { + fn update_persisted_channel( + &self, funding_txo: OutPoint, update: Option<&channelmonitor::ChannelMonitorUpdate>, + _data: &channelmonitor::ChannelMonitor, + ) -> chain::ChannelMonitorUpdateStatus { let mut ret = chain::ChannelMonitorUpdateStatus::Completed; if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { ret = update_ret; } if let Some(update) = update { - self.offchain_monitor_updates.lock().unwrap().entry(funding_txo).or_insert(new_hash_set()).insert(update.update_id); + self.offchain_monitor_updates + .lock() + .unwrap() + .entry(funding_txo) + .or_insert(new_hash_set()) + .insert(update.update_id); } else { self.chain_sync_monitor_persistences.lock().unwrap().push_back(funding_txo); } @@ -606,7 +764,9 @@ impl TestStore { } impl KVStore for TestStore { - fn read(&self, primary_namespace: &str, secondary_namespace: &str, key: &str) -> io::Result> { + fn read( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, + ) -> io::Result> { let persisted_lock = self.persisted_bytes.lock().unwrap(); let prefixed = if secondary_namespace.is_empty() { primary_namespace.to_string() @@ -626,7 +786,9 @@ impl KVStore for TestStore { } } - fn write(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8]) -> io::Result<()> { + fn write( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: &[u8], + ) -> io::Result<()> { if self.read_only { return Err(io::Error::new( io::ErrorKind::PermissionDenied, @@ -647,7 +809,9 @@ impl KVStore for TestStore { Ok(()) } - fn remove(&self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool) -> io::Result<()> { + fn remove( + &self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool, + ) -> io::Result<()> { if self.read_only { return Err(io::Error::new( io::ErrorKind::PermissionDenied, @@ -663,7 +827,7 @@ impl KVStore for TestStore { format!("{}/{}", primary_namespace, secondary_namespace) }; if let Some(outer_ref) = persisted_lock.get_mut(&prefixed) { - outer_ref.remove(&key.to_string()); + outer_ref.remove(&key.to_string()); } Ok(()) @@ -694,14 +858,14 @@ pub struct TestBroadcaster { impl TestBroadcaster { pub fn new(network: Network) -> Self { - Self { - txn_broadcasted: Mutex::new(Vec::new()), - blocks: Arc::new(Mutex::new(vec![(genesis_block(network), 0)])), - } + let txn_broadcasted = Mutex::new(Vec::new()); + let blocks = Arc::new(Mutex::new(vec![(genesis_block(network), 0)])); + Self { txn_broadcasted, blocks } } pub fn with_blocks(blocks: Arc>>) -> Self { - Self { txn_broadcasted: Mutex::new(Vec::new()), blocks } + let txn_broadcasted = Mutex::new(Vec::new()); + Self { txn_broadcasted, blocks } } pub fn txn_broadcast(&self) -> Vec { @@ -721,10 +885,15 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster { for tx in txs { let lock_time = tx.lock_time.to_consensus_u32(); assert!(lock_time < 1_500_000_000); - if tx.lock_time.is_block_height() && lock_time > self.blocks.lock().unwrap().last().unwrap().1 { + if tx.lock_time.is_block_height() + && lock_time > self.blocks.lock().unwrap().last().unwrap().1 + { for inp in tx.input.iter() { if inp.sequence != Sequence::MAX { - panic!("We should never broadcast a transaction before its locktime ({})!", tx.lock_time); + panic!( + "We should never broadcast a transaction before its locktime ({})!", + tx.lock_time + ); } } } @@ -749,10 +918,13 @@ impl TestChannelMessageHandler { impl TestChannelMessageHandler { pub fn new(chain_hash: ChainHash) -> Self { + let pending_events = Mutex::new(Vec::new()); + let expected_recv_msgs = Mutex::new(None); + let connected_peers = Mutex::new(new_hash_set()); TestChannelMessageHandler { - pending_events: Mutex::new(Vec::new()), - expected_recv_msgs: Mutex::new(None), - connected_peers: Mutex::new(new_hash_set()), + pending_events, + expected_recv_msgs, + connected_peers, chain_hash, } } @@ -760,14 +932,21 @@ impl TestChannelMessageHandler { #[cfg(test)] pub(crate) fn expect_receive_msg(&self, ev: wire::Message<()>) { let mut expected_msgs = self.expected_recv_msgs.lock().unwrap(); - if expected_msgs.is_none() { *expected_msgs = Some(Vec::new()); } + if expected_msgs.is_none() { + *expected_msgs = Some(Vec::new()); + } expected_msgs.as_mut().unwrap().push(ev); } fn received_msg(&self, _ev: wire::Message<()>) { let mut msgs = self.expected_recv_msgs.lock().unwrap(); - if msgs.is_none() { return; } - assert!(!msgs.as_ref().unwrap().is_empty(), "Received message when we weren't expecting one"); + if msgs.is_none() { + return; + } + assert!( + !msgs.as_ref().unwrap().is_empty(), + "Received message when we weren't expecting one" + ); #[cfg(test)] assert_eq!(msgs.as_ref().unwrap()[0], _ev); msgs.as_mut().unwrap().remove(0); @@ -829,7 +1008,9 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { fn handle_update_fail_htlc(&self, _their_node_id: PublicKey, msg: &msgs::UpdateFailHTLC) { self.received_msg(wire::Message::UpdateFailHTLC(msg.clone())); } - fn handle_update_fail_malformed_htlc(&self, _their_node_id: PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { + fn handle_update_fail_malformed_htlc( + &self, _their_node_id: PublicKey, msg: &msgs::UpdateFailMalformedHTLC, + ) { self.received_msg(wire::Message::UpdateFailMalformedHTLC(msg.clone())); } fn handle_commitment_signed(&self, _their_node_id: PublicKey, msg: &msgs::CommitmentSigned) { @@ -844,16 +1025,22 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { fn handle_channel_update(&self, _their_node_id: PublicKey, _msg: &msgs::ChannelUpdate) { // Don't call `received_msg` here as `TestRoutingMessageHandler` generates these sometimes } - fn handle_announcement_signatures(&self, _their_node_id: PublicKey, msg: &msgs::AnnouncementSignatures) { + fn handle_announcement_signatures( + &self, _their_node_id: PublicKey, msg: &msgs::AnnouncementSignatures, + ) { self.received_msg(wire::Message::AnnouncementSignatures(msg.clone())); } - fn handle_channel_reestablish(&self, _their_node_id: PublicKey, msg: &msgs::ChannelReestablish) { + fn handle_channel_reestablish( + &self, _their_node_id: PublicKey, msg: &msgs::ChannelReestablish, + ) { self.received_msg(wire::Message::ChannelReestablish(msg.clone())); } fn peer_disconnected(&self, their_node_id: PublicKey) { assert!(self.connected_peers.lock().unwrap().remove(&their_node_id)); } - fn peer_connected(&self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool) -> Result<(), ()> { + fn peer_connected( + &self, their_node_id: PublicKey, _msg: &msgs::Init, _inbound: bool, + ) -> Result<(), ()> { assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone())); // Don't bother with `received_msg` for Init as its auto-generated and we don't want to // bother re-generating the expected Init message in all tests. @@ -977,7 +1164,7 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate { fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: vec![], - } + }, } } @@ -991,28 +1178,38 @@ pub struct TestRoutingMessageHandler { impl TestRoutingMessageHandler { pub fn new() -> Self { + let pending_events = Mutex::new(vec![]); TestRoutingMessageHandler { chan_upds_recvd: AtomicUsize::new(0), chan_anns_recvd: AtomicUsize::new(0), - pending_events: Mutex::new(vec![]), + pending_events, request_full_sync: AtomicBool::new(false), announcement_available_for_sync: AtomicBool::new(false), } } } impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { - fn handle_node_announcement(&self, _their_node_id: Option, _msg: &msgs::NodeAnnouncement) -> Result { + fn handle_node_announcement( + &self, _their_node_id: Option, _msg: &msgs::NodeAnnouncement, + ) -> Result { Ok(true) } - fn handle_channel_announcement(&self, _their_node_id: Option, _msg: &msgs::ChannelAnnouncement) -> Result { + fn handle_channel_announcement( + &self, _their_node_id: Option, _msg: &msgs::ChannelAnnouncement, + ) -> Result { self.chan_anns_recvd.fetch_add(1, Ordering::AcqRel); Ok(true) } - fn handle_channel_update(&self, _their_node_id: Option, _msg: &msgs::ChannelUpdate) -> Result { + fn handle_channel_update( + &self, _their_node_id: Option, _msg: &msgs::ChannelUpdate, + ) -> Result { self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel); Ok(true) } - fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(msgs::ChannelAnnouncement, Option, Option)> { + fn get_next_channel_announcement( + &self, starting_point: u64, + ) -> Option<(msgs::ChannelAnnouncement, Option, Option)> + { if self.announcement_available_for_sync.load(Ordering::Acquire) { let chan_upd_1 = get_dummy_channel_update(starting_point); let chan_upd_2 = get_dummy_channel_update(starting_point); @@ -1024,11 +1221,15 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { } } - fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option { + fn get_next_node_announcement( + &self, _starting_point: Option<&NodeId>, + ) -> Option { None } - fn peer_connected(&self, their_node_id: PublicKey, init_msg: &msgs::Init, _inbound: bool) -> Result<(), ()> { + fn peer_connected( + &self, their_node_id: PublicKey, init_msg: &msgs::Init, _inbound: bool, + ) -> Result<(), ()> { if !init_msg.features.supports_gossip_queries() { return Ok(()); } @@ -1038,7 +1239,10 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { #[cfg(feature = "std")] { use std::time::{SystemTime, UNIX_EPOCH}; - gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs(); + gossip_start_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time must be > 1970") + .as_secs(); if self.request_full_sync.load(Ordering::Acquire) { gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago } else { @@ -1058,19 +1262,27 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { Ok(()) } - fn handle_reply_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), msgs::LightningError> { + fn handle_reply_channel_range( + &self, _their_node_id: PublicKey, _msg: msgs::ReplyChannelRange, + ) -> Result<(), msgs::LightningError> { Ok(()) } - fn handle_reply_short_channel_ids_end(&self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), msgs::LightningError> { + fn handle_reply_short_channel_ids_end( + &self, _their_node_id: PublicKey, _msg: msgs::ReplyShortChannelIdsEnd, + ) -> Result<(), msgs::LightningError> { Ok(()) } - fn handle_query_channel_range(&self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), msgs::LightningError> { + fn handle_query_channel_range( + &self, _their_node_id: PublicKey, _msg: msgs::QueryChannelRange, + ) -> Result<(), msgs::LightningError> { Ok(()) } - fn handle_query_short_channel_ids(&self, _their_node_id: PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), msgs::LightningError> { + fn handle_query_short_channel_ids( + &self, _their_node_id: PublicKey, _msg: msgs::QueryShortChannelIds, + ) -> Result<(), msgs::LightningError> { Ok(()) } @@ -1086,7 +1298,9 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { features } - fn processing_queue_high(&self) -> bool { false } + fn processing_queue_high(&self) -> bool { + false + } } impl events::MessageSendEventsProvider for TestRoutingMessageHandler { @@ -1109,11 +1323,9 @@ impl TestLogger { Self::with_id("".to_owned()) } pub fn with_id(id: String) -> TestLogger { - TestLogger { - id, - lines: Mutex::new(new_hash_map()), - context: Mutex::new(new_hash_map()), - } + let lines = Mutex::new(new_hash_map()); + let context = Mutex::new(new_hash_map()); + TestLogger { id, lines, context } } pub fn assert_log(&self, module: &str, line: String, count: usize) { let log_entries = self.lines.lock().unwrap(); @@ -1126,9 +1338,11 @@ impl TestLogger { /// And asserts if the number of occurrences is the same with the given `count` pub fn assert_log_contains(&self, module: &str, line: &str, count: usize) { let log_entries = self.lines.lock().unwrap(); - let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| { - *m == module && l.contains(line) - }).map(|(_, c) | { c }).sum(); + let l: usize = log_entries + .iter() + .filter(|&(&(ref m, ref l), _c)| *m == module && l.contains(line)) + .map(|(_, c)| c) + .sum(); assert_eq!(l, count) } @@ -1139,14 +1353,17 @@ impl TestLogger { #[cfg(any(test, feature = "_test_utils"))] pub fn assert_log_regex(&self, module: &str, pattern: regex::Regex, count: usize) { let log_entries = self.lines.lock().unwrap(); - let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| { - *m == module && pattern.is_match(&l) - }).map(|(_, c) | { c }).sum(); + let l: usize = log_entries + .iter() + .filter(|&(&(ref m, ref l), _c)| *m == module && pattern.is_match(&l)) + .map(|(_, c)| c) + .sum(); assert_eq!(l, count) } pub fn assert_log_context_contains( - &self, module: &str, peer_id: Option, channel_id: Option, count: usize + &self, module: &str, peer_id: Option, channel_id: Option, + count: usize, ) { let context_entries = self.context.lock().unwrap(); let l = context_entries.get(&(module, peer_id, channel_id)).unwrap(); @@ -1156,11 +1373,19 @@ impl TestLogger { impl Logger for TestLogger { fn log(&self, record: Record) { - let s = format!("{:<55} {}", - format_args!("{} {} [{}:{}]", self.id, record.level.to_string(), record.module_path, record.line), + let s = format!( + "{:<55} {}", + format_args!( + "{} {} [{}:{}]", + self.id, + record.level.to_string(), + record.module_path, + record.line + ), record.args ); - #[cfg(ldk_bench)] { + #[cfg(ldk_bench)] + { // When benchmarking, we don't actually want to print logs, but we do want to format // them. To make sure LLVM doesn't skip the above entirely we push it through a // volitile read. This may not be super fast, but it shouldn't be worse than anything a @@ -1170,9 +1395,20 @@ impl Logger for TestLogger { let _ = unsafe { core::ptr::read_volatile(&s_bytes[i]) }; } } - #[cfg(not(ldk_bench))] { - *self.lines.lock().unwrap().entry((record.module_path, format!("{}", record.args))).or_insert(0) += 1; - *self.context.lock().unwrap().entry((record.module_path, record.peer_id, record.channel_id)).or_insert(0) += 1; + #[cfg(not(ldk_bench))] + { + *self + .lines + .lock() + .unwrap() + .entry((record.module_path, format!("{}", record.args))) + .or_insert(0) += 1; + *self + .context + .lock() + .unwrap() + .entry((record.module_path, record.peer_id, record.channel_id)) + .or_insert(0) += 1; println!("{}", s); } } @@ -1196,15 +1432,18 @@ impl NodeSigner for TestNodeSigner { fn get_node_id(&self, recipient: Recipient) -> Result { let node_secret = match recipient { Recipient::Node => Ok(&self.node_secret), - Recipient::PhantomNode => Err(()) + Recipient::PhantomNode => Err(()), }?; Ok(PublicKey::from_secret_key(&Secp256k1::signing_only(), node_secret)) } - fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&bitcoin::secp256k1::Scalar>) -> Result { + fn ecdh( + &self, recipient: Recipient, other_key: &PublicKey, + tweak: Option<&bitcoin::secp256k1::Scalar>, + ) -> Result { let mut node_secret = match recipient { Recipient::Node => Ok(self.node_secret.clone()), - Recipient::PhantomNode => Err(()) + Recipient::PhantomNode => Err(()), }?; if let Some(tweak) = tweak { node_secret = node_secret.mul_tweak(tweak).map_err(|_| ())?; @@ -1231,7 +1470,7 @@ pub struct TestKeysInterface { pub backing: sign::PhantomKeysManager, pub override_random_bytes: Mutex>, pub disable_revocation_policy_check: bool, - enforcement_states: Mutex>>>, + enforcement_states: Mutex>>>, expectations: Mutex>>, pub unavailable_signers_ops: Mutex>>, pub next_signer_disabled_ops: Mutex>, @@ -1252,7 +1491,9 @@ impl NodeSigner for TestKeysInterface { self.backing.get_node_id(recipient) } - fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&Scalar>) -> Result { + fn ecdh( + &self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&Scalar>, + ) -> Result { self.backing.ecdh(recipient, other_key, tweak) } @@ -1260,7 +1501,9 @@ impl NodeSigner for TestKeysInterface { self.backing.get_inbound_payment_key() } - fn sign_invoice(&self, invoice: &RawBolt11Invoice, recipient: Recipient) -> Result { + fn sign_invoice( + &self, invoice: &RawBolt11Invoice, recipient: Recipient, + ) -> Result { self.backing.sign_invoice(invoice, recipient) } @@ -1280,14 +1523,19 @@ impl SignerProvider for TestKeysInterface { #[cfg(taproot)] type TaprootSigner = TestChannelSigner; - fn generate_channel_keys_id(&self, inbound: bool, channel_value_satoshis: u64, user_channel_id: u128) -> [u8; 32] { + fn generate_channel_keys_id( + &self, inbound: bool, channel_value_satoshis: u64, user_channel_id: u128, + ) -> [u8; 32] { self.backing.generate_channel_keys_id(inbound, channel_value_satoshis, user_channel_id) } - fn derive_channel_signer(&self, channel_value_satoshis: u64, channel_keys_id: [u8; 32]) -> TestChannelSigner { + fn derive_channel_signer( + &self, channel_value_satoshis: u64, channel_keys_id: [u8; 32], + ) -> TestChannelSigner { let keys = self.backing.derive_channel_signer(channel_value_satoshis, channel_keys_id); let state = self.make_enforcement_state_cell(keys.commitment_seed); - let signer = TestChannelSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check); + let signer = + TestChannelSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check); #[cfg(test)] if let Some(ops) = self.unavailable_signers_ops.lock().unwrap().get(&channel_keys_id) { for &op in ops { @@ -1307,14 +1555,12 @@ impl SignerProvider for TestKeysInterface { let inner: InMemorySigner = ReadableArgs::read(&mut reader, self)?; let state = self.make_enforcement_state_cell(inner.commitment_seed); - Ok(TestChannelSigner::new_with_revoked( - inner, - state, - self.disable_revocation_policy_check - )) + Ok(TestChannelSigner::new_with_revoked(inner, state, self.disable_revocation_policy_check)) } - fn get_destination_script(&self, channel_keys_id: [u8; 32]) -> Result { self.backing.get_destination_script(channel_keys_id) } + fn get_destination_script(&self, channel_keys_id: [u8; 32]) -> Result { + self.backing.get_destination_script(channel_keys_id) + } fn get_shutdown_scriptpubkey(&self) -> Result { match &mut *self.expectations.lock().unwrap() { @@ -1330,31 +1576,42 @@ impl SignerProvider for TestKeysInterface { impl TestKeysInterface { pub fn new(seed: &[u8; 32], network: Network) -> Self { let now = Duration::from_secs(genesis_block(network).header.time as u64); + let override_random_bytes = Mutex::new(None); + let enforcement_states = Mutex::new(new_hash_map()); + let expectations = Mutex::new(None); + let unavailable_signers_ops = Mutex::new(new_hash_map()); + let next_signer_disabled_ops = Mutex::new(new_hash_set()); Self { backing: sign::PhantomKeysManager::new(seed, now.as_secs(), now.subsec_nanos(), seed), - override_random_bytes: Mutex::new(None), + override_random_bytes, disable_revocation_policy_check: false, - enforcement_states: Mutex::new(new_hash_map()), - expectations: Mutex::new(None), - unavailable_signers_ops: Mutex::new(new_hash_map()), - next_signer_disabled_ops: Mutex::new(new_hash_set()), + enforcement_states, + expectations, + unavailable_signers_ops, + next_signer_disabled_ops, } } /// Sets an expectation that [`sign::SignerProvider::get_shutdown_scriptpubkey`] is /// called. pub fn expect(&self, expectation: OnGetShutdownScriptpubkey) -> &Self { - self.expectations.lock().unwrap() + self.expectations + .lock() + .unwrap() .get_or_insert_with(|| VecDeque::new()) .push_back(expectation); self } - pub fn derive_channel_keys(&self, channel_value_satoshis: u64, id: &[u8; 32]) -> TestChannelSigner { + pub fn derive_channel_keys( + &self, channel_value_satoshis: u64, id: &[u8; 32], + ) -> TestChannelSigner { self.derive_channel_signer(channel_value_satoshis, *id) } - fn make_enforcement_state_cell(&self, commitment_seed: [u8; 32]) -> Arc> { + fn make_enforcement_state_cell( + &self, commitment_seed: [u8; 32], + ) -> Arc> { let mut states = self.enforcement_states.lock().unwrap(); if !states.contains_key(&commitment_seed) { let state = EnforcementState::new(); @@ -1403,12 +1660,16 @@ pub struct TestChainSource { impl TestChainSource { pub fn new(network: Network) -> Self { let script_pubkey = Builder::new().push_opcode(opcodes::OP_TRUE).into_script(); + let utxo_ret = + Mutex::new(UtxoResult::Sync(Ok(TxOut { value: Amount::MAX, script_pubkey }))); + let watched_txn = Mutex::new(new_hash_set()); + let watched_outputs = Mutex::new(new_hash_set()); Self { chain_hash: ChainHash::using_genesis_block(network), - utxo_ret: Mutex::new(UtxoResult::Sync(Ok(TxOut { value: Amount::MAX, script_pubkey }))), + utxo_ret, get_utxo_call_count: AtomicUsize::new(0), - watched_txn: Mutex::new(new_hash_set()), - watched_outputs: Mutex::new(new_hash_set()), + watched_txn, + watched_outputs, } } pub fn remove_watched_txn_and_outputs(&self, outpoint: OutPoint, script_pubkey: ScriptBuf) { @@ -1453,25 +1714,29 @@ pub struct TestScorer { impl TestScorer { pub fn new() -> Self { - Self { - scorer_expectations: RefCell::new(None), - } + Self { scorer_expectations: RefCell::new(None) } } pub fn expect_usage(&self, scid: u64, expectation: ChannelUsage) { - self.scorer_expectations.borrow_mut().get_or_insert_with(|| VecDeque::new()).push_back((scid, expectation)); + self.scorer_expectations + .borrow_mut() + .get_or_insert_with(|| VecDeque::new()) + .push_back((scid, expectation)); } } #[cfg(c_bindings)] impl crate::util::ser::Writeable for TestScorer { - fn write(&self, _: &mut W) -> Result<(), crate::io::Error> { unreachable!(); } + fn write(&self, _: &mut W) -> Result<(), crate::io::Error> { + unreachable!(); + } } impl ScoreLookUp for TestScorer { type ScoreParams = (); fn channel_penalty_msat( - &self, candidate: &CandidateRouteHop, usage: ChannelUsage, _score_params: &Self::ScoreParams + &self, candidate: &CandidateRouteHop, usage: ChannelUsage, + _score_params: &Self::ScoreParams, ) -> u64 { let short_channel_id = match candidate.globally_unique_short_channel_id() { Some(scid) => scid, @@ -1491,7 +1756,11 @@ impl ScoreLookUp for TestScorer { } impl ScoreUpdate for TestScorer { - fn payment_path_failed(&mut self, _actual_path: &Path, _actual_short_channel_id: u64, _duration_since_epoch: Duration) {} + fn payment_path_failed( + &mut self, _actual_path: &Path, _actual_short_channel_id: u64, + _duration_since_epoch: Duration, + ) { + } fn payment_path_successful(&mut self, _actual_path: &Path, _duration_since_epoch: Duration) {} @@ -1527,11 +1796,7 @@ pub struct TestWalletSource { impl TestWalletSource { pub fn new(secret_key: SecretKey) -> Self { - Self { - secret_key, - utxos: RefCell::new(Vec::new()), - secp: Secp256k1::new(), - } + Self { secret_key, utxos: RefCell::new(Vec::new()), secp: Secp256k1::new() } } pub fn add_utxo(&self, outpoint: bitcoin::OutPoint, value: Amount) -> TxOut { @@ -1566,12 +1831,22 @@ impl WalletSource for TestWalletSource { let mut tx = psbt.extract_tx_unchecked_fee_rate(); let utxos = self.utxos.borrow(); for i in 0..tx.input.len() { - if let Some(utxo) = utxos.iter().find(|utxo| utxo.outpoint == tx.input[i].previous_output) { + if let Some(utxo) = + utxos.iter().find(|utxo| utxo.outpoint == tx.input[i].previous_output) + { let sighash = SighashCache::new(&tx) - .legacy_signature_hash(i, &utxo.output.script_pubkey, EcdsaSighashType::All as u32) + .legacy_signature_hash( + i, + &utxo.output.script_pubkey, + EcdsaSighashType::All as u32, + ) .map_err(|_| ())?; - let signature = self.secp.sign_ecdsa(&secp256k1::Message::from_digest(sighash.to_byte_array()), &self.secret_key); - let bitcoin_sig = bitcoin::ecdsa::Signature { signature, sighash_type: EcdsaSighashType::All }; + let signature = self.secp.sign_ecdsa( + &secp256k1::Message::from_digest(sighash.to_byte_array()), + &self.secret_key, + ); + let bitcoin_sig = + bitcoin::ecdsa::Signature { signature, sighash_type: EcdsaSighashType::All }; tx.input[i].script_sig = Builder::new() .push_slice(&bitcoin_sig.serialize()) .push_slice(&self.secret_key.public_key(&self.secp).serialize()) diff --git a/lightning/src/util/time.rs b/lightning/src/util/time.rs index 106a4ce4e17..d79dddc6950 100644 --- a/lightning/src/util/time.rs +++ b/lightning/src/util/time.rs @@ -7,16 +7,16 @@ //! A simple module which either re-exports [`std::time::Instant`] or a mocked version of it for //! tests. -#[cfg(test)] -pub use test::Instant; #[cfg(not(test))] pub use std::time::Instant; +#[cfg(test)] +pub use test::Instant; #[cfg(test)] mod test { - use core::time::Duration; - use core::ops::Sub; use core::cell::Cell; + use core::ops::Sub; + use core::time::Duration; /// Time that can be advanced manually in tests. #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/lightning/src/util/transaction_utils.rs b/lightning/src/util/transaction_utils.rs index da8ff434301..09cc7be62c3 100644 --- a/lightning/src/util/transaction_utils.rs +++ b/lightning/src/util/transaction_utils.rs @@ -8,10 +8,10 @@ // licenses. use bitcoin::amount::Amount; -use bitcoin::transaction::{Transaction, TxOut}; -use bitcoin::script::ScriptBuf; -use bitcoin::consensus::Encodable; use bitcoin::consensus::encode::VarInt; +use bitcoin::consensus::Encodable; +use bitcoin::script::ScriptBuf; +use bitcoin::transaction::{Transaction, TxOut}; #[allow(unused_imports)] use crate::prelude::*; @@ -19,12 +19,10 @@ use crate::prelude::*; use crate::io_extras::sink; use core::cmp::Ordering; -pub fn sort_outputs Ordering>(outputs: &mut Vec<(TxOut, T)>, tie_breaker: C) { +pub fn sort_outputs Ordering>(outputs: &mut Vec<(TxOut, T)>, tie_breaker: C) { outputs.sort_unstable_by(|a, b| { a.0.value.cmp(&b.0.value).then_with(|| { - a.0.script_pubkey[..].cmp(&b.0.script_pubkey[..]).then_with(|| { - tie_breaker(&a.1, &b.1) - }) + a.0.script_pubkey[..].cmp(&b.0.script_pubkey[..]).then_with(|| tie_breaker(&a.1, &b.1)) }) }); } @@ -34,34 +32,41 @@ pub fn sort_outputs Ordering>(outputs: &mut Vec<(TxOut, T)> /// Assumes at least one input will have a witness (ie spends a segwit output). /// Returns an Err(()) if the requested feerate cannot be met. /// Returns the expected maximum weight of the fully signed transaction on success. -pub(crate) fn maybe_add_change_output(tx: &mut Transaction, input_value: Amount, witness_max_weight: u64, feerate_sat_per_1000_weight: u32, change_destination_script: ScriptBuf) -> Result { - if input_value > Amount::MAX_MONEY { return Err(()); } +pub(crate) fn maybe_add_change_output( + tx: &mut Transaction, input_value: Amount, witness_max_weight: u64, + feerate_sat_per_1000_weight: u32, change_destination_script: ScriptBuf, +) -> Result { + if input_value > Amount::MAX_MONEY { + return Err(()); + } const WITNESS_FLAG_BYTES: u64 = 2; let mut output_value = Amount::ZERO; for output in tx.output.iter() { output_value += output.value; - if output_value >= input_value { return Err(()); } + if output_value >= input_value { + return Err(()); + } } let dust_value = change_destination_script.minimal_non_dust(); - let mut change_output = TxOut { - script_pubkey: change_destination_script, - value: Amount::ZERO, - }; + let mut change_output = TxOut { script_pubkey: change_destination_script, value: Amount::ZERO }; let change_len = change_output.consensus_encode(&mut sink()).unwrap(); let starting_weight = tx.weight().to_wu() + WITNESS_FLAG_BYTES + witness_max_weight as u64; + let starting_fees = (starting_weight as i64) * feerate_sat_per_1000_weight as i64 / 1000; let mut weight_with_change: i64 = starting_weight as i64 + change_len as i64 * 4; // Include any extra bytes required to push an extra output. - weight_with_change += (VarInt(tx.output.len() as u64 + 1).size() - VarInt(tx.output.len() as u64).size()) as i64 * 4; + let num_outputs = tx.output.len() as u64; + weight_with_change += (VarInt(num_outputs + 1).size() - VarInt(num_outputs).size()) as i64 * 4; // When calculating weight, add two for the flag bytes - let change_value: i64 = (input_value - output_value).to_sat() as i64 - weight_with_change * feerate_sat_per_1000_weight as i64 / 1000; + let fees_with_change = weight_with_change * feerate_sat_per_1000_weight as i64 / 1000; + let change_value: i64 = (input_value - output_value).to_sat() as i64 - fees_with_change; if change_value >= dust_value.to_sat() as i64 { change_output.value = Amount::from_sat(change_value as u64); tx.output.push(change_output); Ok(weight_with_change as u64) - } else if (input_value - output_value).to_sat() as i64 - (starting_weight as i64) * feerate_sat_per_1000_weight as i64 / 1000 < 0 { + } else if (input_value - output_value).to_sat() as i64 - starting_fees < 0 { Err(()) } else { Ok(starting_weight) @@ -73,12 +78,12 @@ mod tests { use super::*; use bitcoin::amount::Amount; - use bitcoin::locktime::absolute::LockTime; - use bitcoin::transaction::{TxIn, OutPoint, Version}; - use bitcoin::script::Builder; use bitcoin::hash_types::Txid; use bitcoin::hashes::Hash; use bitcoin::hex::FromHex; + use bitcoin::locktime::absolute::LockTime; + use bitcoin::script::Builder; + use bitcoin::transaction::{OutPoint, TxIn, Version}; use bitcoin::{PubkeyHash, Sequence, Witness}; use alloc::vec; @@ -86,47 +91,45 @@ mod tests { #[test] fn sort_output_by_value() { let txout1 = TxOut { - value: Amount::from_sat(100), - script_pubkey: Builder::new().push_int(0).into_script() + value: Amount::from_sat(100), + script_pubkey: Builder::new().push_int(0).into_script(), }; let txout1_ = txout1.clone(); let txout2 = TxOut { value: Amount::from_sat(99), - script_pubkey: Builder::new().push_int(0).into_script() + script_pubkey: Builder::new().push_int(0).into_script(), }; let txout2_ = txout2.clone(); let mut outputs = vec![(txout1, "ignore"), (txout2, "ignore")]; - sort_outputs(&mut outputs, |_, _| { unreachable!(); }); + sort_outputs(&mut outputs, |_, _| { + unreachable!(); + }); - assert_eq!( - &outputs, - &vec![(txout2_, "ignore"), (txout1_, "ignore")] - ); + assert_eq!(&outputs, &vec![(txout2_, "ignore"), (txout1_, "ignore")]); } #[test] fn sort_output_by_script_pubkey() { let txout1 = TxOut { - value: Amount::from_sat(100), + value: Amount::from_sat(100), script_pubkey: Builder::new().push_int(3).into_script(), }; let txout1_ = txout1.clone(); let txout2 = TxOut { value: Amount::from_sat(100), - script_pubkey: Builder::new().push_int(1).push_int(2).into_script() + script_pubkey: Builder::new().push_int(1).push_int(2).into_script(), }; let txout2_ = txout2.clone(); let mut outputs = vec![(txout1, "ignore"), (txout2, "ignore")]; - sort_outputs(&mut outputs, |_, _| { unreachable!(); }); + sort_outputs(&mut outputs, |_, _| { + unreachable!(); + }); - assert_eq!( - &outputs, - &vec![(txout2_, "ignore"), (txout1_, "ignore")] - ); + assert_eq!(&outputs, &vec![(txout2_, "ignore"), (txout1_, "ignore")]); } #[test] @@ -145,7 +148,9 @@ mod tests { let txout2_ = txout2.clone(); let mut outputs = vec![(txout1, "ignore"), (txout2, "ignore")]; - sort_outputs(&mut outputs, |_, _| { unreachable!(); }); + sort_outputs(&mut outputs, |_, _| { + unreachable!(); + }); assert_eq!(&outputs, &vec![(txout1_, "ignore"), (txout2_, "ignore")]); } @@ -153,8 +158,8 @@ mod tests { #[test] fn sort_output_tie_breaker_test() { let txout1 = TxOut { - value: Amount::from_sat(100), - script_pubkey: Builder::new().push_int(1).push_int(2).into_script() + value: Amount::from_sat(100), + script_pubkey: Builder::new().push_int(1).push_int(2).into_script(), }; let txout1_ = txout1.clone(); @@ -162,12 +167,9 @@ mod tests { let txout2_ = txout1.clone(); let mut outputs = vec![(txout1, 420), (txout2, 69)]; - sort_outputs(&mut outputs, |a, b| { a.cmp(b) }); + sort_outputs(&mut outputs, |a, b| a.cmp(b)); - assert_eq!( - &outputs, - &vec![(txout2_, 69), (txout1_, 420)] - ); + assert_eq!(&outputs, &vec![(txout2_, 69), (txout1_, 420)]); } fn script_from_hex(hex_str: &str) -> ScriptBuf { @@ -215,18 +217,28 @@ mod tests { #[test] fn test_tx_value_overrun() { // If we have a bogus input amount or outputs valued more than inputs, we should fail - let mut tx = Transaction { version: Version::TWO, lock_time: LockTime::ZERO, input: Vec::new(), output: vec![TxOut { - script_pubkey: ScriptBuf::new(), value: Amount::from_sat(1000) - }] }; - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(21_000_000_0000_0001), 0, 253, ScriptBuf::new()).is_err()); - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(400), 0, 253, ScriptBuf::new()).is_err()); - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(4000), 0, 253, ScriptBuf::new()).is_ok()); + let version = Version::TWO; + let lock_time = LockTime::ZERO; + let input = Vec::new(); + let tx_out = TxOut { script_pubkey: ScriptBuf::new(), value: Amount::from_sat(1000) }; + let output = vec![tx_out]; + let mut tx = Transaction { version, lock_time, input, output }; + let amount = Amount::from_sat(21_000_000_0000_0001); + assert!(maybe_add_change_output(&mut tx, amount, 0, 253, ScriptBuf::new()).is_err()); + let amount = Amount::from_sat(400); + assert!(maybe_add_change_output(&mut tx, amount, 0, 253, ScriptBuf::new()).is_err()); + let amount = Amount::from_sat(4000); + assert!(maybe_add_change_output(&mut tx, amount, 0, 253, ScriptBuf::new()).is_ok()); } #[test] fn test_tx_change_edge() { // Check that we never add dust outputs - let mut tx = Transaction { version: Version::TWO, lock_time: LockTime::ZERO, input: Vec::new(), output: Vec::new() }; + let version = Version::TWO; + let lock_time = LockTime::ZERO; + let input = Vec::new(); + let output = Vec::new(); + let mut tx = Transaction { version, lock_time, input, output }; let orig_wtxid = tx.compute_wtxid(); let output_spk = ScriptBuf::new_p2pkh(&PubkeyHash::hash(&[0; 0])); assert_eq!(output_spk.minimal_non_dust().to_sat(), 546); @@ -235,38 +247,56 @@ mod tests { // weight = 3 * base size + total size = 3 * (4 + 1 + 0 + 1 + 0 + 4) + (4 + 1 + 1 + 1 + 0 + 1 + 0 + 4) = 3 * 10 + 12 = 42 assert_eq!(tx.weight().to_wu(), 42); // 10 sats isn't enough to pay fee on a dummy transaction... - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(10), 0, 250, output_spk.clone()).is_err()); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // Failure doesn't change the transaction + let amount = Amount::from_sat(10); + assert!(maybe_add_change_output(&mut tx, amount, 0, 250, output_spk.clone()).is_err()); + // Failure doesn't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); // but 11 (= ceil(42 * 250 / 1000)) is, just not enough to add a change output... - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(11), 0, 250, output_spk.clone()).is_ok()); + let amount = Amount::from_sat(11); + assert!(maybe_add_change_output(&mut tx, amount, 0, 250, output_spk.clone()).is_ok()); assert_eq!(tx.output.len(), 0); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // If we don't add an output, we don't change the transaction - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(549), 0, 250, output_spk.clone()).is_ok()); + // If we don't add an output, we don't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); + let amount = Amount::from_sat(549); + assert!(maybe_add_change_output(&mut tx, amount, 0, 250, output_spk.clone()).is_ok()); assert_eq!(tx.output.len(), 0); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // If we don't add an output, we don't change the transaction + // If we don't add an output, we don't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); // 590 is also not enough - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(590), 0, 250, output_spk.clone()).is_ok()); + let amount = Amount::from_sat(590); + assert!(maybe_add_change_output(&mut tx, amount, 0, 250, output_spk.clone()).is_ok()); assert_eq!(tx.output.len(), 0); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // If we don't add an output, we don't change the transaction + // If we don't add an output, we don't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); // at 591 we can afford the change output at the dust limit (546) - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(591), 0, 250, output_spk.clone()).is_ok()); + let amount = Amount::from_sat(591); + assert!(maybe_add_change_output(&mut tx, amount, 0, 250, output_spk.clone()).is_ok()); assert_eq!(tx.output.len(), 1); assert_eq!(tx.output[0].value.to_sat(), 546); assert_eq!(tx.output[0].script_pubkey, output_spk); - assert_eq!(tx.weight().to_wu() / 4, 590-546); // New weight is exactly the fee we wanted. + // New weight is exactly the fee we wanted. + assert_eq!(tx.weight().to_wu() / 4, 590 - 546); tx.output.pop(); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // The only change is the addition of one output. + // The only change is the addition of one output. + assert_eq!(tx.compute_wtxid(), orig_wtxid); } #[test] fn test_tx_extra_outputs() { // Check that we correctly handle existing outputs - let mut tx = Transaction { version: Version::TWO, lock_time: LockTime::ZERO, input: vec![TxIn { - previous_output: OutPoint::new(Txid::all_zeros(), 0), script_sig: ScriptBuf::new(), witness: Witness::new(), sequence: Sequence::ZERO, - }], output: vec![TxOut { - script_pubkey: Builder::new().push_int(1).into_script(), value: Amount::from_sat(1000) - }] }; + let script_pubkey = Builder::new().push_int(1).into_script(); + let tx_out = TxOut { script_pubkey, value: Amount::from_sat(1000) }; + let previous_output = OutPoint::new(Txid::all_zeros(), 0); + let script_sig = ScriptBuf::new(); + let witness = Witness::new(); + let sequence = Sequence::ZERO; + let tx_in = TxIn { previous_output, script_sig, witness, sequence }; + let version = Version::TWO; + let lock_time = LockTime::ZERO; + let input = vec![tx_in]; + let output = vec![tx_out]; + let mut tx = Transaction { version, lock_time, input, output }; let orig_wtxid = tx.compute_wtxid(); let orig_weight = tx.weight().to_wu(); assert_eq!(orig_weight / 4, 61); @@ -274,21 +304,34 @@ mod tests { assert_eq!(Builder::new().push_int(2).into_script().minimal_non_dust().to_sat(), 474); // Input value of the output value + fee - 1 should fail: - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(1000 + 61 + 100 - 1), 400, 250, Builder::new().push_int(2).into_script()).is_err()); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // Failure doesn't change the transaction + let amount = Amount::from_sat(1000 + 61 + 100 - 1); + let script = Builder::new().push_int(2).into_script(); + assert!(maybe_add_change_output(&mut tx, amount, 400, 250, script).is_err()); + // Failure doesn't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); // but one more input sat should succeed, without changing the transaction - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(1000 + 61 + 100), 400, 250, Builder::new().push_int(2).into_script()).is_ok()); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // If we don't add an output, we don't change the transaction + let amount = Amount::from_sat(1000 + 61 + 100); + let script = Builder::new().push_int(2).into_script(); + assert!(maybe_add_change_output(&mut tx, amount, 400, 250, script).is_ok()); + // If we don't add an output, we don't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); // In order to get a change output, we need to add 474 plus the output's weight / 4 (10)... - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(1000 + 61 + 100 + 474 + 9), 400, 250, Builder::new().push_int(2).into_script()).is_ok()); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // If we don't add an output, we don't change the transaction - - assert!(maybe_add_change_output(&mut tx, Amount::from_sat(1000 + 61 + 100 + 474 + 10), 400, 250, Builder::new().push_int(2).into_script()).is_ok()); + let amount = Amount::from_sat(1000 + 61 + 100 + 474 + 9); + let script = Builder::new().push_int(2).into_script(); + assert!(maybe_add_change_output(&mut tx, amount, 400, 250, script).is_ok()); + // If we don't add an output, we don't change the transaction + assert_eq!(tx.compute_wtxid(), orig_wtxid); + + let amount = Amount::from_sat(1000 + 61 + 100 + 474 + 10); + let script = Builder::new().push_int(2).into_script(); + assert!(maybe_add_change_output(&mut tx, amount, 400, 250, script).is_ok()); assert_eq!(tx.output.len(), 2); assert_eq!(tx.output[1].value.to_sat(), 474); assert_eq!(tx.output[1].script_pubkey, Builder::new().push_int(2).into_script()); - assert_eq!(tx.weight().to_wu() - orig_weight, 40); // Weight difference matches what we had to add above + // Weight difference matches what we had to add above + assert_eq!(tx.weight().to_wu() - orig_weight, 40); tx.output.pop(); - assert_eq!(tx.compute_wtxid(), orig_wtxid); // The only change is the addition of one output. + // The only change is the addition of one output. + assert_eq!(tx.compute_wtxid(), orig_wtxid); } } diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index a01948f3ea1..a23e866ec18 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -13,9 +13,9 @@ //! //! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager +use crate::sync::Mutex; use alloc::sync::Arc; use core::mem; -use crate::sync::Mutex; #[allow(unused_imports)] use crate::prelude::*; @@ -26,9 +26,8 @@ use crate::sync::Condvar; use std::time::Duration; use core::future::Future as StdFuture; -use core::task::{Context, Poll}; use core::pin::Pin; - +use core::task::{Context, Poll}; /// Used to signal to one of many waiters that the condition they're waiting on has happened. pub(crate) struct Notifier { @@ -37,9 +36,7 @@ pub(crate) struct Notifier { impl Notifier { pub(crate) fn new() -> Self { - Self { - notify_pending: Mutex::new((false, None)), - } + Self { notify_pending: Mutex::new((false, None)) } } /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters. @@ -198,7 +195,9 @@ impl Future { if state.complete { state.callbacks_made = true; true - } else { false } + } else { + false + } } } @@ -251,11 +250,8 @@ impl Sleeper { // Note that this is the common case - a ChannelManager, a ChainMonitor, and an // OnionMessenger. pub fn from_three_futures(fut_a: &Future, fut_b: &Future, fut_c: &Future) -> Self { - let notifiers = vec![ - Arc::clone(&fut_a.state), - Arc::clone(&fut_b.state), - Arc::clone(&fut_c.state) - ]; + let notifiers = + vec![Arc::clone(&fut_a.state), Arc::clone(&fut_b.state), Arc::clone(&fut_c.state)]; Self { notifiers } } /// Constructs a new sleeper on many futures, allowing blocking on all at once. @@ -289,8 +285,11 @@ impl Sleeper { /// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed. pub fn wait(&self) { let (cv, notified_fut_mtx) = self.setup_wait(); - let notified_fut = cv.wait_while(notified_fut_mtx.lock().unwrap(), |fut_opt| fut_opt.is_none()) - .unwrap().take().expect("CV wait shouldn't have returned until the notifying future was set"); + let notified_fut = cv + .wait_while(notified_fut_mtx.lock().unwrap(), |fut_opt| fut_opt.is_none()) + .unwrap() + .take() + .expect("CV wait shouldn't have returned until the notifying future was set"); notified_fut.lock().unwrap().callbacks_made = true; } @@ -300,10 +299,13 @@ impl Sleeper { pub fn wait_timeout(&self, max_wait: Duration) -> bool { let (cv, notified_fut_mtx) = self.setup_wait(); let notified_fut = - match cv.wait_timeout_while(notified_fut_mtx.lock().unwrap(), max_wait, |fut_opt| fut_opt.is_none()) { + match cv.wait_timeout_while(notified_fut_mtx.lock().unwrap(), max_wait, |fut_opt| { + fut_opt.is_none() + }) { Ok((_, e)) if e.timed_out() => return false, - Ok((mut notified_fut, _)) => - notified_fut.take().expect("CV wait shouldn't have returned until the notifying future was set"), + Ok((mut notified_fut, _)) => notified_fut + .take() + .expect("CV wait shouldn't have returned until the notifying future was set"), Err(_) => panic!("Previous panic while a lock was held led to a lock panic"), }; notified_fut.lock().unwrap().callbacks_made = true; @@ -314,8 +316,8 @@ impl Sleeper { #[cfg(test)] mod tests { use super::*; - use core::sync::atomic::{AtomicBool, Ordering}; use core::future::Future as FutureTrait; + use core::sync::atomic::{AtomicBool, Ordering}; use core::task::{RawWaker, RawWakerVTable}; #[test] @@ -328,7 +330,9 @@ mod tests { let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(callback.load(Ordering::SeqCst)); } @@ -343,7 +347,9 @@ mod tests { // a second `notify`. let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(!callback.load(Ordering::SeqCst)); notifier.notify(); @@ -351,7 +357,9 @@ mod tests { let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(!callback.load(Ordering::SeqCst)); notifier.notify(); @@ -365,12 +373,16 @@ mod tests { let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - future.register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + future.register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(callback.load(Ordering::SeqCst)); let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(!callback.load(Ordering::SeqCst)); } @@ -384,12 +396,16 @@ mod tests { let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(callback.load(Ordering::SeqCst)); let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + notifier.get_future().register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(!callback.load(Ordering::SeqCst)); notifier.notify(); @@ -407,12 +423,10 @@ mod tests { let exit_thread = Arc::new(AtomicBool::new(false)); let exit_thread_clone = exit_thread.clone(); - thread::spawn(move || { - loop { - thread_notifier.notify(); - if exit_thread_clone.load(Ordering::SeqCst) { - break - } + thread::spawn(move || loop { + thread_notifier.notify(); + if exit_thread_clone.load(Ordering::SeqCst) { + break; } }); @@ -423,7 +437,7 @@ mod tests { // available. loop { if persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { - break + break; } } @@ -433,7 +447,7 @@ mod tests { // are available. loop { if !persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { - break + break; } } } @@ -493,7 +507,9 @@ mod tests { }; let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - future.register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + future.register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(!callback.load(Ordering::SeqCst)); complete_future(&future.state); @@ -518,7 +534,9 @@ mod tests { let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); - future.register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + future.register_callback(Box::new(move || { + assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(callback.load(Ordering::SeqCst)); assert!(future.state.lock().unwrap().callbacks.is_empty()); @@ -529,9 +547,18 @@ mod tests { // compared to a raw VTable). Instead, we have to write out a lot of boilerplate to build a // waker, which we do here with a trivial Arc data element to track woke-ness. const WAKER_V_TABLE: RawWakerVTable = RawWakerVTable::new(waker_clone, wake, wake_by_ref, drop); - unsafe fn wake_by_ref(ptr: *const ()) { let p = ptr as *const Arc; assert!(!(*p).fetch_or(true, Ordering::SeqCst)); } - unsafe fn drop(ptr: *const ()) { let p = ptr as *mut Arc; let _freed = Box::from_raw(p); } - unsafe fn wake(ptr: *const ()) { wake_by_ref(ptr); drop(ptr); } + unsafe fn wake_by_ref(ptr: *const ()) { + let p = ptr as *const Arc; + assert!(!(*p).fetch_or(true, Ordering::SeqCst)); + } + unsafe fn drop(ptr: *const ()) { + let p = ptr as *mut Arc; + let _freed = Box::from_raw(p); + } + unsafe fn wake(ptr: *const ()) { + wake_by_ref(ptr); + drop(ptr); + } unsafe fn waker_clone(ptr: *const ()) -> RawWaker { let p = ptr as *const Arc; RawWaker::new(Box::into_raw(Box::new(Arc::clone(&*p))) as *const (), &WAKER_V_TABLE) @@ -539,7 +566,8 @@ mod tests { fn create_waker() -> (Arc, Waker) { let a = Arc::new(AtomicBool::new(false)); - let waker = unsafe { Waker::from_raw(waker_clone((&a as *const Arc) as *const ())) }; + let waker = + unsafe { Waker::from_raw(waker_clone((&a as *const Arc) as *const ())) }; (a, waker) } @@ -563,14 +591,20 @@ mod tests { assert!(!woken.load(Ordering::SeqCst)); let (second_woken, second_waker) = create_waker(); - assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Pending); + assert_eq!( + Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), + Poll::Pending + ); assert!(!second_woken.load(Ordering::SeqCst)); complete_future(&future.state); assert!(woken.load(Ordering::SeqCst)); assert!(second_woken.load(Ordering::SeqCst)); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(())); - assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Ready(())); + assert_eq!( + Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), + Poll::Ready(()) + ); } #[test] @@ -713,8 +747,12 @@ mod tests { let callback_b = Arc::new(AtomicBool::new(false)); let callback_a_ref = Arc::clone(&callback_a); let callback_b_ref = Arc::clone(&callback_b); - notifier_a.get_future().register_callback(Box::new(move || assert!(!callback_a_ref.fetch_or(true, Ordering::SeqCst)))); - notifier_b.get_future().register_callback(Box::new(move || assert!(!callback_b_ref.fetch_or(true, Ordering::SeqCst)))); + notifier_a.get_future().register_callback(Box::new(move || { + assert!(!callback_a_ref.fetch_or(true, Ordering::SeqCst)) + })); + notifier_b.get_future().register_callback(Box::new(move || { + assert!(!callback_b_ref.fetch_or(true, Ordering::SeqCst)) + })); assert!(callback_a.load(Ordering::SeqCst) ^ callback_b.load(Ordering::SeqCst)); // If we now notify both notifiers again, the other callback will fire, completing the @@ -739,14 +777,23 @@ mod tests { // Test that simply polling a future twice doesn't result in two pending `Waker`s. let mut future_a = notifier.get_future(); - assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!( + Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Pending + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); - assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!( + Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Pending + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); // If we poll a second future, however, that will store a second `Waker`. let mut future_b = notifier.get_future(); - assert_eq!(Pin::new(&mut future_b).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!( + Pin::new(&mut future_b).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Pending + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 2); // but when we drop the `Future`s, the pending Wakers will also be dropped. @@ -757,13 +804,22 @@ mod tests { // Further, after polling a future twice, if the notifier is woken all Wakers are dropped. let mut future_a = notifier.get_future(); - assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!( + Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Pending + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); - assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Pending); + assert_eq!( + Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Pending + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 1); notifier.notify(); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); - assert_eq!(Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), Poll::Ready(())); + assert_eq!( + Pin::new(&mut future_a).poll(&mut Context::from_waker(&create_waker().1)), + Poll::Ready(()) + ); assert_eq!(future_state.lock().unwrap().std_future_callbacks.len(), 0); } } diff --git a/rustfmt_excluded_files b/rustfmt_excluded_files index a71a16f6e8d..973ecff3392 100644 --- a/rustfmt_excluded_files +++ b/rustfmt_excluded_files @@ -63,15 +63,3 @@ lightning/src/routing/router.rs lightning/src/routing/scoring.rs lightning/src/routing/test_utils.rs lightning/src/routing/utxo.rs -lightning/src/util/invoice.rs -lightning/src/util/message_signing.rs -lightning/src/util/mod.rs -lightning/src/util/scid_utils.rs -lightning/src/util/ser.rs -lightning/src/util/ser_macros.rs -lightning/src/util/string.rs -lightning/src/util/test_channel_signer.rs -lightning/src/util/test_utils.rs -lightning/src/util/time.rs -lightning/src/util/transaction_utils.rs -lightning/src/util/wakers.rs