From eb39808fddc85947140e6997d1e162c323220540 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E5=AE=87?= Date: Fri, 27 Oct 2023 14:14:25 +0800 Subject: [PATCH] feat(wallet)!: add `new_or_load` methods These methods try to load wallet from persistence and initializes the wallet instead if non-existant. An internal helper method `create_signers` is added to reuse code. Documentation is also improved. --- crates/bdk/src/wallet/mod.rs | 239 +++++++++++++++++++++++++++-------- 1 file changed, 189 insertions(+), 50 deletions(-) diff --git a/crates/bdk/src/wallet/mod.rs b/crates/bdk/src/wallet/mod.rs index 38166a52f5..8b3a6e5748 100644 --- a/crates/bdk/src/wallet/mod.rs +++ b/crates/bdk/src/wallet/mod.rs @@ -28,7 +28,7 @@ use bdk_chain::{ Append, BlockId, ChainPosition, ConfirmationTime, ConfirmationTimeAnchor, FullTxOut, IndexedTxGraph, Persist, PersistBackend, }; -use bitcoin::secp256k1::Secp256k1; +use bitcoin::secp256k1::{All, Secp256k1}; use bitcoin::sighash::{EcdsaSighashType, TapSighashType}; use bitcoin::{ absolute, Address, Network, OutPoint, Script, ScriptBuf, Sequence, Transaction, TxOut, Txid, @@ -246,8 +246,13 @@ impl Wallet { } } +/// The error type when constructing a fresh [`Wallet`]. +/// +/// Methods [`new`] and [`new_with_genesis_hash`] may return this error. +/// +/// [`new`]: Wallet::new +/// [`new_with_genesis_hash`]: Wallet::new_with_genesis_hash #[derive(Debug)] -/// Error returned from [`Wallet::new`] pub enum NewError { /// There was problem with the passed-in descriptor(s). Descriptor(crate::descriptor::DescriptorError), @@ -270,7 +275,11 @@ where #[cfg(feature = "std")] impl std::error::Error for NewError where W: core::fmt::Display + core::fmt::Debug {} -/// An error that may occur when loading a [`Wallet`] from persistence. +/// The error type when loading a [`Wallet`] from persistence. +/// +/// Method [`load`] may return this error. +/// +/// [`load`]: Wallet::load #[derive(Debug)] pub enum LoadError { /// There was a problem with the passed-in descriptor(s). @@ -300,6 +309,64 @@ where #[cfg(feature = "std")] impl std::error::Error for LoadError where L: core::fmt::Display + core::fmt::Debug {} +/// Error type for when we try load a [`Wallet`] from persistence and creating it if non-existant. +/// +/// Methods [`new_or_load`] and [`new_or_load_with_genesis_hash`] may return this error. +/// +/// [`new_or_load`]: Wallet::new_or_load +/// [`new_or_load_with_genesis_hash`]: Wallet::new_or_load_with_genesis_hash +#[derive(Debug)] +pub enum NewOrLoadError { + /// There is a problem with the passed-in descriptor. + Descriptor(crate::descriptor::DescriptorError), + /// Writing to the persistence backend failed. + Write(W), + /// Loading from the persistence backend failed. + Load(L), + /// The loaded genesis hash does not match what was provided. + LoadedGenesisDoesNotMatch { + /// The expected genesis block hash. + expected: BlockHash, + /// The block hash loaded from persistence. + got: Option, + }, + /// The loaded network type does not match what was provided. + LoadedNetworkDoesNotMatch { + /// The expected network type. + expected: Network, + /// The network type loaded from persistence. + got: Option, + }, +} + +impl fmt::Display for NewOrLoadError +where + W: fmt::Display, + L: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NewOrLoadError::Descriptor(e) => e.fmt(f), + NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e), + NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e), + NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => { + write!(f, "loaded genesis hash is not {}, got {:?}", expected, got) + } + NewOrLoadError::LoadedNetworkDoesNotMatch { expected, got } => { + write!(f, "loaded network type is not {}, got {:?}", expected, got) + } + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for NewOrLoadError +where + W: core::fmt::Display + core::fmt::Debug, + L: core::fmt::Display + core::fmt::Debug, +{ +} + /// An error that may occur when inserting a transaction into [`Wallet`]. #[derive(Debug)] pub enum InsertTxError { @@ -314,8 +381,7 @@ pub enum InsertTxError { } impl Wallet { - /// Create a wallet from a `descriptor` (and an optional `change_descriptor`) and load related - /// transaction data from `db`. + /// Initialize an empty [`Wallet`]. pub fn new( descriptor: E, change_descriptor: Option, @@ -329,9 +395,10 @@ impl Wallet { Self::new_with_genesis_hash(descriptor, change_descriptor, db, network, genesis_hash) } - /// Create a new [`Wallet`] with a custom genesis hash. + /// Initialize an empty [`Wallet`] with a custom genesis hash. /// - /// This is like [`Wallet::new`] with an additional `custom_genesis_hash` parameter. + /// This is like [`Wallet::new`] with an additional `genesis_hash` parameter. This is useful + /// for syncing from alternative networks. pub fn new_with_genesis_hash( descriptor: E, change_descriptor: Option, @@ -343,33 +410,18 @@ impl Wallet { D: PersistBackend, { let secp = Secp256k1::new(); - let (chain, _) = LocalChain::from_genesis_hash(genesis_hash); - let mut indexed_graph = - IndexedTxGraph::>::default(); + let (chain, chain_changeset) = LocalChain::from_genesis_hash(genesis_hash); + let mut index = KeychainTxOutIndex::::default(); - let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network) - .map_err(NewError::Descriptor)?; - indexed_graph - .index - .add_keychain(KeychainKind::External, descriptor.clone()); - let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp)); - - let change_signers = Arc::new(match change_descriptor { - Some(desc) => { - let (descriptor, keymap) = into_wallet_descriptor_checked(desc, &secp, network) - .map_err(NewError::Descriptor)?; - let signers = SignersContainer::build(keymap, &descriptor, &secp); - indexed_graph - .index - .add_keychain(KeychainKind::Internal, descriptor); - signers - } - None => SignersContainer::new(), - }); + let (signers, change_signers) = + create_signers(&mut index, &secp, descriptor, change_descriptor, network) + .map_err(NewError::Descriptor)?; + + let indexed_graph = IndexedTxGraph::new(index); let mut persist = Persist::new(db); persist.stage(ChangeSet { - chain: chain.initial_changeset(), + chain: chain_changeset, indexed_tx_graph: indexed_graph.initial_changeset(), network: Some(network), }); @@ -386,7 +438,7 @@ impl Wallet { }) } - /// Load [`Wallet`] from persistence. + /// Load [`Wallet`] from the given persistence backend. pub fn load( descriptor: E, change_descriptor: Option, @@ -396,31 +448,15 @@ impl Wallet { D: PersistBackend, { let secp = Secp256k1::new(); - let changeset = db.load_from_persistence().map_err(LoadError::Load)?; let network = changeset.network.ok_or(LoadError::MissingNetwork)?; - let chain = LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?; - let mut index = KeychainTxOutIndex::::default(); - let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, &secp, network) - .map_err(LoadError::Descriptor)?; - let signers = Arc::new(SignersContainer::build(keymap, &descriptor, &secp)); - index.add_keychain(KeychainKind::External, descriptor); - - let change_signers = Arc::new(match change_descriptor { - Some(descriptor) => { - let (descriptor, keymap) = - into_wallet_descriptor_checked(descriptor, &secp, network) - .map_err(LoadError::Descriptor)?; - let signers = SignersContainer::build(keymap, &descriptor, &secp); - index.add_keychain(KeychainKind::Internal, descriptor); - signers - } - None => SignersContainer::new(), - }); + let (signers, change_signers) = + create_signers(&mut index, &secp, descriptor, change_descriptor, network) + .map_err(LoadError::Descriptor)?; let indexed_graph = IndexedTxGraph::new(index); let persist = Persist::new(db); @@ -436,6 +472,85 @@ impl Wallet { }) } + /// Either loads [`Wallet`] from persistence, or initializes it if it does not exist. + /// + /// This method will fail if the loaded [`Wallet`] has different parameters to those provided. + pub fn new_or_load( + descriptor: E, + change_descriptor: Option, + db: D, + network: Network, + ) -> Result> + where + D: PersistBackend, + { + let genesis_hash = genesis_block(network).block_hash(); + Self::new_or_load_with_genesis_hash( + descriptor, + change_descriptor, + db, + network, + genesis_hash, + ) + } + + /// Either loads [`Wallet`] from persistence, or initializes it if it does not exist (with a + /// custom genesis hash). + /// + /// This method will fail if the loaded [`Wallet`] has different parameters to those provided. + /// This is like [`Wallet::new_or_load`] with an additional `genesis_hash` parameter. This is + /// useful for syncing from alternative networks. + pub fn new_or_load_with_genesis_hash( + descriptor: E, + change_descriptor: Option, + mut db: D, + network: Network, + genesis_hash: BlockHash, + ) -> Result> + where + D: PersistBackend, + { + if db.is_empty().map_err(NewOrLoadError::Load)? { + return Self::new_with_genesis_hash( + descriptor, + change_descriptor, + db, + network, + genesis_hash, + ) + .map_err(|e| match e { + NewError::Descriptor(e) => NewOrLoadError::Descriptor(e), + NewError::Write(e) => NewOrLoadError::Write(e), + }); + } + + let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e { + LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e), + LoadError::Load(e) => NewOrLoadError::Load(e), + LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch { + expected: network, + got: None, + }, + LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch { + expected: genesis_hash, + got: None, + }, + })?; + if wallet.chain.genesis_hash() != genesis_hash { + return Err(NewOrLoadError::LoadedGenesisDoesNotMatch { + expected: genesis_hash, + got: Some(wallet.chain.genesis_hash()), + }); + } + if wallet.network != network { + return Err(NewOrLoadError::LoadedNetworkDoesNotMatch { + expected: network, + got: Some(wallet.network), + }); + } + Ok(wallet) + } + /// Get the Bitcoin network the wallet is using. pub fn network(&self) -> Network { self.network @@ -2149,6 +2264,30 @@ fn new_local_utxo( } } +fn create_signers( + index: &mut KeychainTxOutIndex, + secp: &Secp256k1, + descriptor: E, + change_descriptor: Option, + network: Network, +) -> Result<(Arc, Arc), crate::descriptor::error::Error> { + let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?; + let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp)); + index.add_keychain(KeychainKind::External, descriptor); + + let change_signers = match change_descriptor { + Some(descriptor) => { + let (descriptor, keymap) = into_wallet_descriptor_checked(descriptor, secp, network)?; + let signers = Arc::new(SignersContainer::build(keymap, &descriptor, secp)); + index.add_keychain(KeychainKind::Internal, descriptor); + signers + } + None => Arc::new(SignersContainer::new()), + }; + + Ok((signers, change_signers)) +} + #[macro_export] #[doc(hidden)] /// Macro for getting a wallet for use in a doctest