From 20c12734555c97e283dce2974ca9d5e63a8badbe Mon Sep 17 00:00:00 2001 From: spacebear Date: Wed, 2 Oct 2024 12:08:38 -0400 Subject: [PATCH] Add proper error handling --- payjoin/src/psbt.rs | 104 +++++++++++++++++++++++++++++------ payjoin/src/receive/error.rs | 10 ++++ payjoin/src/receive/mod.rs | 7 ++- payjoin/src/send/error.rs | 21 +++++-- payjoin/src/send/mod.rs | 42 +++++++++----- 5 files changed, 142 insertions(+), 42 deletions(-) diff --git a/payjoin/src/psbt.rs b/payjoin/src/psbt.rs index c7510dd7..aaa4ccf7 100644 --- a/payjoin/src/psbt.rs +++ b/payjoin/src/psbt.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use std::fmt; +use bitcoin::address::FromScriptError; use bitcoin::blockdata::script::Instruction; use bitcoin::psbt::Psbt; use bitcoin::transaction::InputWeightPrediction; @@ -174,42 +175,44 @@ impl<'a> InputPair<'a> { } } - pub fn address_type(&self) -> AddressType { - let txo = self.previous_txout().expect("PrevTxoutError"); + pub fn address_type(&self) -> Result { + let txo = self.previous_txout()?; // HACK: Network doesn't matter for our use case of only getting the address type // but is required in the `from_script` interface. Hardcoded to mainnet. - Address::from_script(&txo.script_pubkey, Network::Bitcoin) - .expect("Unrecognized script") + Address::from_script(&txo.script_pubkey, Network::Bitcoin)? .address_type() - .expect("UnknownAddressType") + .ok_or(AddressTypeError::UnknownAddressType) } - pub fn expected_input_weight(&self) -> Weight { + pub fn expected_input_weight(&self) -> Result { use bitcoin::AddressType::*; // Get the input weight prediction corresponding to spending an output of this address type - let iwp = match self.address_type() { - P2pkh => InputWeightPrediction::P2PKH_COMPRESSED_MAX, + let iwp = match self.address_type()? { + P2pkh => Ok(InputWeightPrediction::P2PKH_COMPRESSED_MAX), P2sh => match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref())) { - Some(script) if script.is_witness_program() && script.is_p2wpkh() => + // Nested segwit p2wpkh. // input script: 0x160014{20-byte-key-hash} = 23 bytes // witness: = 72, 33 bytes // https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh - InputWeightPrediction::new(23, &[72, 33]), - Some(_) => unimplemented!(), - None => panic!("Input not finalized!"), + Some(script) if script.is_witness_program() && script.is_p2wpkh() => + Ok(InputWeightPrediction::new(23, &[72, 33])), + // Other script or witness program. + Some(_) => Err(InputWeightError::NotSupported), + // No redeem script provided. Cannot determine the script type. + None => Err(InputWeightError::NotFinalized), }, - P2wpkh => InputWeightPrediction::P2WPKH_MAX, - P2wsh => unimplemented!(), - P2tr => InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH, - _ => panic!("Unknown address type!"), - }; + P2wpkh => Ok(InputWeightPrediction::P2WPKH_MAX), + P2wsh => Err(InputWeightError::NotSupported), + P2tr => Ok(InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH), + _ => Err(AddressTypeError::UnknownAddressType.into()), + }?; // Lengths of txid, index and sequence: (32, 4, 4). let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4); - input_weight + Ok(input_weight) } } @@ -279,3 +282,68 @@ impl fmt::Display for PsbtInputsError { impl std::error::Error for PsbtInputsError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) } } + +#[derive(Debug)] +pub(crate) enum AddressTypeError { + PrevTxOut(PrevTxOutError), + InvalidScript(FromScriptError), + UnknownAddressType, +} + +impl fmt::Display for AddressTypeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AddressTypeError::PrevTxOut(_) => write!(f, "invalid previous transaction output"), + AddressTypeError::InvalidScript(_) => write!(f, "invalid script"), + AddressTypeError::UnknownAddressType => write!(f, "unknown address type"), + } + } +} + +impl std::error::Error for AddressTypeError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + AddressTypeError::PrevTxOut(error) => Some(error), + AddressTypeError::InvalidScript(error) => Some(error), + AddressTypeError::UnknownAddressType => None, + } + } +} + +impl From for AddressTypeError { + fn from(value: PrevTxOutError) -> Self { AddressTypeError::PrevTxOut(value) } +} + +impl From for AddressTypeError { + fn from(value: FromScriptError) -> Self { AddressTypeError::InvalidScript(value) } +} + +#[derive(Debug)] +pub(crate) enum InputWeightError { + AddressType(AddressTypeError), + NotFinalized, + NotSupported, +} + +impl fmt::Display for InputWeightError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + InputWeightError::AddressType(_) => write!(f, "invalid address type"), + InputWeightError::NotFinalized => write!(f, "input not finalized"), + InputWeightError::NotSupported => write!(f, "weight prediction not supported"), + } + } +} + +impl std::error::Error for InputWeightError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + InputWeightError::AddressType(error) => Some(error), + InputWeightError::NotFinalized => None, + InputWeightError::NotSupported => None, + } + } +} +impl From for InputWeightError { + fn from(value: AddressTypeError) -> Self { InputWeightError::AddressType(value) } +} diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 95aa68ec..f1ba34d4 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -74,6 +74,10 @@ pub(crate) enum InternalRequestError { InputOwned(bitcoin::ScriptBuf), /// The original psbt has mixed input address types that could harm privacy MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType), + /// The address type could not be determined + AddressType(crate::psbt::AddressTypeError), + /// The expected input weight cannot be determined + InputWeight(crate::psbt::InputWeightError), /// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec /// look out for these to prevent probing attacks. InputSeen(bitcoin::OutPoint), @@ -153,6 +157,10 @@ impl fmt::Display for RequestError { "original-psbt-rejected", &format!("Mixed input scripts: {}; {}.", type_a, type_b), ), + InternalRequestError::AddressType(e) => + write_error(f, "original-psbt-rejected", &format!("AddressType Error: {}", e)), + InternalRequestError::InputWeight(e) => + write_error(f, "original-psbt-rejected", &format!("InputWeight Error: {}", e)), InternalRequestError::InputSeen(_) => write_error(f, "original-psbt-rejected", "The receiver rejected the original PSBT."), #[cfg(feature = "v2")] @@ -192,6 +200,8 @@ impl std::error::Error for RequestError { InternalRequestError::SenderParams(e) => Some(e), InternalRequestError::InconsistentPsbt(e) => Some(e), InternalRequestError::PrevTxOut(e) => Some(e), + InternalRequestError::AddressType(e) => Some(e), + InternalRequestError::InputWeight(e) => Some(e), #[cfg(feature = "v2")] InternalRequestError::ParsePsbt(e) => Some(e), #[cfg(feature = "v2")] diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index 4bc5682b..0f31a8c4 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -228,10 +228,10 @@ impl MaybeMixedInputScripts { let input_scripts = self .psbt .input_pairs() - .scan(&mut err, |err, input| match Ok(input.address_type()) { + .scan(&mut err, |err, input| match input.address_type() { Ok(address_type) => Some(address_type), Err(e) => { - **err = Err(RequestError::from(InternalRequestError::PrevTxOut(e))); + **err = Err(RequestError::from(InternalRequestError::AddressType(e))); None } }) @@ -750,7 +750,8 @@ impl ProvisionalProposal { // Calculate the additional weight contribution let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len(); log::trace!("input_count : {}", input_count); - let weight_per_input = input_pair.expected_input_weight(); + let weight_per_input = + input_pair.expected_input_weight().map_err(InternalRequestError::InputWeight)?; log::trace!("weight_per_input : {}", weight_per_input); let contribution_weight = weight_per_input * input_count as u64; log::trace!("contribution_weight: {}", contribution_weight); diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index ee510f34..8f240e3f 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -17,7 +17,7 @@ pub struct ValidationError { pub(crate) enum InternalValidationError { Parse, Io(std::io::Error), - InvalidProposedInput(crate::psbt::PrevTxOutError), + InvalidAddressType(crate::psbt::AddressTypeError), VersionsDontMatch { proposed: Version, original: Version, @@ -66,6 +66,12 @@ impl From for ValidationError { fn from(value: InternalValidationError) -> Self { ValidationError { internal: value } } } +impl From for InternalValidationError { + fn from(value: crate::psbt::AddressTypeError) -> Self { + InternalValidationError::InvalidAddressType(value) + } +} + impl fmt::Display for ValidationError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use InternalValidationError::*; @@ -73,7 +79,7 @@ impl fmt::Display for ValidationError { match &self.internal { Parse => write!(f, "couldn't decode as PSBT or JSON",), Io(e) => write!(f, "couldn't read PSBT: {}", e), - InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e), + InvalidAddressType(e) => write!(f, "invalid input address type: {}", e), VersionsDontMatch { proposed, original, } => write!(f, "proposed transaction version {} doesn't match the original {}", proposed, original), LockTimesDontMatch { proposed, original, } => write!(f, "proposed transaction lock time {} doesn't match the original {}", proposed, original), SenderTxinSequenceChanged { proposed, original, } => write!(f, "proposed transaction sequence number {} doesn't match the original {}", proposed, original), @@ -115,7 +121,7 @@ impl std::error::Error for ValidationError { match &self.internal { Parse => None, Io(error) => Some(error), - InvalidProposedInput(error) => Some(error), + InvalidAddressType(error) => Some(error), VersionsDontMatch { proposed: _, original: _ } => None, LockTimesDontMatch { proposed: _, original: _ } => None, SenderTxinSequenceChanged { proposed: _, original: _ } => None, @@ -172,7 +178,8 @@ pub(crate) enum InternalCreateRequestError { ChangeIndexOutOfBounds, ChangeIndexPointsAtPayee, Url(url::ParseError), - PrevTxOut(crate::psbt::PrevTxOutError), + AddressType(crate::psbt::AddressTypeError), + InputWeight(crate::psbt::InputWeightError), #[cfg(feature = "v2")] Hpke(crate::v2::HpkeError), #[cfg(feature = "v2")] @@ -202,7 +209,8 @@ impl fmt::Display for CreateRequestError { ChangeIndexOutOfBounds => write!(f, "fee output index is points out of bounds"), ChangeIndexPointsAtPayee => write!(f, "fee output index is points at output belonging to the payee"), Url(e) => write!(f, "cannot parse url: {:#?}", e), - PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e), + AddressType(e) => write!(f, "can not determine input address type: {}", e), + InputWeight(e) => write!(f, "can not determine expected input weight: {}", e), #[cfg(feature = "v2")] Hpke(e) => write!(f, "v2 error: {}", e), #[cfg(feature = "v2")] @@ -234,7 +242,8 @@ impl std::error::Error for CreateRequestError { ChangeIndexOutOfBounds => None, ChangeIndexPointsAtPayee => None, Url(error) => Some(error), - PrevTxOut(error) => Some(error), + AddressType(error) => Some(error), + InputWeight(error) => Some(error), #[cfg(feature = "v2")] Hpke(error) => Some(error), #[cfg(feature = "v2")] diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index a36e6945..bc2bc75a 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -133,18 +133,22 @@ impl<'a> RequestBuilder<'a> { let first_input_pair = input_pairs.first().ok_or(InternalCreateRequestError::NoInputs)?; - // use cheapest default if mixed input types - let mut input_weight = - bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight() - // Lengths of txid, index and sequence: (32, 4, 4). - + Weight::from_non_witness_data_size(32 + 4 + 4); - // Check if all inputs are the same type - if input_pairs + let input_weight = if input_pairs .iter() - .all(|input_pair| input_pair.address_type() == first_input_pair.address_type()) + .try_fold(true, |_, input_pair| -> Result { + Ok(input_pair.address_type()? == first_input_pair.address_type()?) + }) + .map_err(InternalCreateRequestError::AddressType)? { - input_weight = first_input_pair.expected_input_weight(); - } + first_input_pair + .expected_input_weight() + .map_err(InternalCreateRequestError::InputWeight)? + } else { + // use cheapest default if mixed input types + bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight() + // Lengths of txid, index and sequence: (32, 4, 4). + + Weight::from_non_witness_data_size(32 + 4 + 4) + }; let recommended_additional_fee = min_fee_rate * input_weight; if fee_available < recommended_additional_fee { @@ -230,7 +234,10 @@ impl<'a> RequestBuilder<'a> { let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?; let sequence = zeroth_input.txin.sequence; - let input_type = zeroth_input.address_type().to_string(); + let input_type = zeroth_input + .address_type() + .map_err(InternalCreateRequestError::AddressType)? + .to_string(); #[cfg(feature = "v2")] let e = { @@ -620,12 +627,17 @@ impl ContextV1 { ensure!(contributed_fee <= proposed_fee - original_fee, PayeeTookContributedFee); let original_weight = self.original_psbt.clone().extract_tx_unchecked_fee_rate().weight(); let original_fee_rate = original_fee / original_weight; - // TODO: Refactor this to be support mixed input types, preferably share method with - // `ProvisionalProposal::additional_input_weight()` + // TODO: This should support mixed input types ensure!( contributed_fee <= original_fee_rate - * self.original_psbt.input_pairs().next().unwrap().expected_input_weight() + * self + .original_psbt + .input_pairs() + .next() + .expect("This shouldn't happen. Failed to get an original input.") + .expected_input_weight() + .expect("This shouldn't happen. Weight should have been calculated successfully before.") * (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64, FeeContributionPaysOutputSizeIncrease ); @@ -697,7 +709,7 @@ impl ContextV1 { ReceiverTxinMissingUtxoInfo ); ensure!(proposed.txin.sequence == self.sequence, MixedSequence); - check_eq!(proposed.address_type(), self.input_type, MixedInputTypes); + check_eq!(proposed.address_type()?, self.input_type, MixedInputTypes); } } }