Skip to content

Commit

Permalink
Add proper error handling to new InputPair functions
Browse files Browse the repository at this point in the history
  • Loading branch information
spacebear21 committed Oct 3, 2024
1 parent e8ca5ab commit f9ff6ad
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 42 deletions.
104 changes: 86 additions & 18 deletions payjoin/src/psbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::collections::BTreeMap;
use std::fmt;

use bitcoin::address::FromScriptError;
use bitcoin::blockdata::script::Instruction;
use bitcoin::psbt::Psbt;
use bitcoin::transaction::InputWeightPrediction;
Expand Down Expand Up @@ -174,42 +175,44 @@ impl<'a> InputPair<'a> {
}
}

pub fn address_type(&self) -> AddressType {
let txo = self.previous_txout().expect("PrevTxoutError");
pub fn address_type(&self) -> Result<AddressType, AddressTypeError> {
let txo = self.previous_txout()?;
// HACK: Network doesn't matter for our use case of only getting the address type
// but is required in the `from_script` interface. Hardcoded to mainnet.
Address::from_script(&txo.script_pubkey, Network::Bitcoin)
.expect("Unrecognized script")
Address::from_script(&txo.script_pubkey, Network::Bitcoin)?
.address_type()
.expect("UnknownAddressType")
.ok_or(AddressTypeError::UnknownAddressType)
}

pub fn expected_input_weight(&self) -> Weight {
pub fn expected_input_weight(&self) -> Result<Weight, InputWeightError> {
use bitcoin::AddressType::*;

// Get the input weight prediction corresponding to spending an output of this address type
let iwp = match self.address_type() {
P2pkh => InputWeightPrediction::P2PKH_COMPRESSED_MAX,
let iwp = match self.address_type()? {
P2pkh => Ok(InputWeightPrediction::P2PKH_COMPRESSED_MAX),
P2sh =>
match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref()))
{
Some(script) if script.is_witness_program() && script.is_p2wpkh() =>
// Nested segwit p2wpkh.
// input script: 0x160014{20-byte-key-hash} = 23 bytes
// witness: <signature> <pubkey> = 72, 33 bytes
// https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh
InputWeightPrediction::new(23, &[72, 33]),
Some(_) => unimplemented!(),
None => panic!("Input not finalized!"),
Some(script) if script.is_witness_program() && script.is_p2wpkh() =>
Ok(InputWeightPrediction::new(23, &[72, 33])),
// Other script or witness program.
Some(_) => Err(InputWeightError::NotSupported),
// No redeem script provided. Cannot determine the script type.
None => Err(InputWeightError::NotFinalized),
},
P2wpkh => InputWeightPrediction::P2WPKH_MAX,
P2wsh => unimplemented!(),
P2tr => InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH,
_ => panic!("Unknown address type!"),
};
P2wpkh => Ok(InputWeightPrediction::P2WPKH_MAX),
P2wsh => Err(InputWeightError::NotSupported),
P2tr => Ok(InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH),
_ => Err(AddressTypeError::UnknownAddressType.into()),
}?;

// Lengths of txid, index and sequence: (32, 4, 4).
let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4);
input_weight
Ok(input_weight)
}
}

Expand Down Expand Up @@ -279,3 +282,68 @@ impl fmt::Display for PsbtInputsError {
impl std::error::Error for PsbtInputsError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { Some(&self.error) }
}

#[derive(Debug)]
pub(crate) enum AddressTypeError {
PrevTxOut(PrevTxOutError),
InvalidScript(FromScriptError),
UnknownAddressType,
}

impl fmt::Display for AddressTypeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
AddressTypeError::PrevTxOut(_) => write!(f, "invalid previous transaction output"),
AddressTypeError::InvalidScript(_) => write!(f, "invalid script"),
AddressTypeError::UnknownAddressType => write!(f, "unknown address type"),
}
}
}

impl std::error::Error for AddressTypeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
AddressTypeError::PrevTxOut(error) => Some(error),
AddressTypeError::InvalidScript(error) => Some(error),
AddressTypeError::UnknownAddressType => None,
}
}
}

impl From<PrevTxOutError> for AddressTypeError {
fn from(value: PrevTxOutError) -> Self { AddressTypeError::PrevTxOut(value) }
}

impl From<FromScriptError> for AddressTypeError {
fn from(value: FromScriptError) -> Self { AddressTypeError::InvalidScript(value) }
}

#[derive(Debug)]
pub(crate) enum InputWeightError {
AddressType(AddressTypeError),
NotFinalized,
NotSupported,
}

impl fmt::Display for InputWeightError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
InputWeightError::AddressType(_) => write!(f, "invalid address type"),
InputWeightError::NotFinalized => write!(f, "input not finalized"),
InputWeightError::NotSupported => write!(f, "weight prediction not supported"),
}
}
}

impl std::error::Error for InputWeightError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
InputWeightError::AddressType(error) => Some(error),
InputWeightError::NotFinalized => None,
InputWeightError::NotSupported => None,
}
}
}
impl From<AddressTypeError> for InputWeightError {
fn from(value: AddressTypeError) -> Self { InputWeightError::AddressType(value) }
}
10 changes: 10 additions & 0 deletions payjoin/src/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ pub(crate) enum InternalRequestError {
InputOwned(bitcoin::ScriptBuf),
/// The original psbt has mixed input address types that could harm privacy
MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType),
/// The address type could not be determined
AddressType(crate::psbt::AddressTypeError),
/// The expected input weight cannot be determined
InputWeight(crate::psbt::InputWeightError),
/// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec
/// look out for these to prevent probing attacks.
InputSeen(bitcoin::OutPoint),
Expand Down Expand Up @@ -153,6 +157,10 @@ impl fmt::Display for RequestError {
"original-psbt-rejected",
&format!("Mixed input scripts: {}; {}.", type_a, type_b),
),
InternalRequestError::AddressType(e) =>
write_error(f, "original-psbt-rejected", &format!("AddressType Error: {}", e)),
InternalRequestError::InputWeight(e) =>
write_error(f, "original-psbt-rejected", &format!("InputWeight Error: {}", e)),
InternalRequestError::InputSeen(_) =>
write_error(f, "original-psbt-rejected", "The receiver rejected the original PSBT."),
#[cfg(feature = "v2")]
Expand Down Expand Up @@ -192,6 +200,8 @@ impl std::error::Error for RequestError {
InternalRequestError::SenderParams(e) => Some(e),
InternalRequestError::InconsistentPsbt(e) => Some(e),
InternalRequestError::PrevTxOut(e) => Some(e),
InternalRequestError::AddressType(e) => Some(e),
InternalRequestError::InputWeight(e) => Some(e),
#[cfg(feature = "v2")]
InternalRequestError::ParsePsbt(e) => Some(e),
#[cfg(feature = "v2")]
Expand Down
7 changes: 4 additions & 3 deletions payjoin/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,10 @@ impl MaybeMixedInputScripts {
let input_scripts = self
.psbt
.input_pairs()
.scan(&mut err, |err, input| match Ok(input.address_type()) {
.scan(&mut err, |err, input| match input.address_type() {
Ok(address_type) => Some(address_type),
Err(e) => {
**err = Err(RequestError::from(InternalRequestError::PrevTxOut(e)));
**err = Err(RequestError::from(InternalRequestError::AddressType(e)));
None
}
})
Expand Down Expand Up @@ -750,7 +750,8 @@ impl ProvisionalProposal {
// Calculate the additional weight contribution
let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len();
log::trace!("input_count : {}", input_count);
let weight_per_input = input_pair.expected_input_weight();
let weight_per_input =
input_pair.expected_input_weight().map_err(InternalRequestError::InputWeight)?;
log::trace!("weight_per_input : {}", weight_per_input);
let contribution_weight = weight_per_input * input_count as u64;
log::trace!("contribution_weight: {}", contribution_weight);
Expand Down
21 changes: 15 additions & 6 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct ValidationError {
pub(crate) enum InternalValidationError {
Parse,
Io(std::io::Error),
InvalidProposedInput(crate::psbt::PrevTxOutError),
InvalidAddressType(crate::psbt::AddressTypeError),
VersionsDontMatch {
proposed: Version,
original: Version,
Expand Down Expand Up @@ -66,14 +66,20 @@ impl From<InternalValidationError> for ValidationError {
fn from(value: InternalValidationError) -> Self { ValidationError { internal: value } }
}

impl From<crate::psbt::AddressTypeError> for InternalValidationError {
fn from(value: crate::psbt::AddressTypeError) -> Self {
InternalValidationError::InvalidAddressType(value)
}
}

impl fmt::Display for ValidationError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use InternalValidationError::*;

match &self.internal {
Parse => write!(f, "couldn't decode as PSBT or JSON",),
Io(e) => write!(f, "couldn't read PSBT: {}", e),
InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e),
InvalidAddressType(e) => write!(f, "invalid input address type: {}", e),
VersionsDontMatch { proposed, original, } => write!(f, "proposed transaction version {} doesn't match the original {}", proposed, original),
LockTimesDontMatch { proposed, original, } => write!(f, "proposed transaction lock time {} doesn't match the original {}", proposed, original),
SenderTxinSequenceChanged { proposed, original, } => write!(f, "proposed transaction sequence number {} doesn't match the original {}", proposed, original),
Expand Down Expand Up @@ -115,7 +121,7 @@ impl std::error::Error for ValidationError {
match &self.internal {
Parse => None,
Io(error) => Some(error),
InvalidProposedInput(error) => Some(error),
InvalidAddressType(error) => Some(error),
VersionsDontMatch { proposed: _, original: _ } => None,
LockTimesDontMatch { proposed: _, original: _ } => None,
SenderTxinSequenceChanged { proposed: _, original: _ } => None,
Expand Down Expand Up @@ -172,7 +178,8 @@ pub(crate) enum InternalCreateRequestError {
ChangeIndexOutOfBounds,
ChangeIndexPointsAtPayee,
Url(url::ParseError),
PrevTxOut(crate::psbt::PrevTxOutError),
AddressType(crate::psbt::AddressTypeError),
InputWeight(crate::psbt::InputWeightError),
#[cfg(feature = "v2")]
Hpke(crate::v2::HpkeError),
#[cfg(feature = "v2")]
Expand Down Expand Up @@ -202,7 +209,8 @@ impl fmt::Display for CreateRequestError {
ChangeIndexOutOfBounds => write!(f, "fee output index is points out of bounds"),
ChangeIndexPointsAtPayee => write!(f, "fee output index is points at output belonging to the payee"),
Url(e) => write!(f, "cannot parse url: {:#?}", e),
PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e),
AddressType(e) => write!(f, "can not determine input address type: {}", e),
InputWeight(e) => write!(f, "can not determine expected input weight: {}", e),
#[cfg(feature = "v2")]
Hpke(e) => write!(f, "v2 error: {}", e),
#[cfg(feature = "v2")]
Expand Down Expand Up @@ -234,7 +242,8 @@ impl std::error::Error for CreateRequestError {
ChangeIndexOutOfBounds => None,
ChangeIndexPointsAtPayee => None,
Url(error) => Some(error),
PrevTxOut(error) => Some(error),
AddressType(error) => Some(error),
InputWeight(error) => Some(error),
#[cfg(feature = "v2")]
Hpke(error) => Some(error),
#[cfg(feature = "v2")]
Expand Down
42 changes: 27 additions & 15 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,22 @@ impl<'a> RequestBuilder<'a> {

let first_input_pair =
input_pairs.first().ok_or(InternalCreateRequestError::NoInputs)?;
// use cheapest default if mixed input types
let mut input_weight =
bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight()
// Lengths of txid, index and sequence: (32, 4, 4).
+ Weight::from_non_witness_data_size(32 + 4 + 4);
// Check if all inputs are the same type
if input_pairs
let input_weight = if input_pairs
.iter()
.all(|input_pair| input_pair.address_type() == first_input_pair.address_type())
.try_fold(true, |_, input_pair| -> Result<bool, crate::psbt::AddressTypeError> {
Ok(input_pair.address_type()? == first_input_pair.address_type()?)
})
.map_err(InternalCreateRequestError::AddressType)?
{
input_weight = first_input_pair.expected_input_weight();
}
first_input_pair
.expected_input_weight()
.map_err(InternalCreateRequestError::InputWeight)?
} else {
// use cheapest default if mixed input types
bitcoin::transaction::InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH.weight()
// Lengths of txid, index and sequence: (32, 4, 4).
+ Weight::from_non_witness_data_size(32 + 4 + 4)
};

let recommended_additional_fee = min_fee_rate * input_weight;
if fee_available < recommended_additional_fee {
Expand Down Expand Up @@ -228,7 +232,10 @@ impl<'a> RequestBuilder<'a> {
let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?;

let sequence = zeroth_input.txin.sequence;
let input_type = zeroth_input.address_type().to_string();
let input_type = zeroth_input
.address_type()
.map_err(InternalCreateRequestError::AddressType)?
.to_string();

Ok(RequestContext {
psbt,
Expand Down Expand Up @@ -493,12 +500,17 @@ impl ContextV1 {
ensure!(contributed_fee <= proposed_fee - original_fee, PayeeTookContributedFee);
let original_weight = self.original_psbt.clone().extract_tx_unchecked_fee_rate().weight();
let original_fee_rate = original_fee / original_weight;
// TODO: Refactor this to be support mixed input types, preferably share method with
// `ProvisionalProposal::additional_input_weight()`
// TODO: This should support mixed input types
ensure!(
contributed_fee
<= original_fee_rate
* self.original_psbt.input_pairs().next().unwrap().expected_input_weight()
* self
.original_psbt
.input_pairs()
.next()
.expect("This shouldn't happen. Failed to get an original input.")
.expected_input_weight()
.expect("This shouldn't happen. Weight should have been calculated successfully before.")
* (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64,
FeeContributionPaysOutputSizeIncrease
);
Expand Down Expand Up @@ -570,7 +582,7 @@ impl ContextV1 {
ReceiverTxinMissingUtxoInfo
);
ensure!(proposed.txin.sequence == self.sequence, MixedSequence);
check_eq!(proposed.address_type(), self.input_type, MixedInputTypes);
check_eq!(proposed.address_type()?, self.input_type, MixedInputTypes);
}
}
}
Expand Down

0 comments on commit f9ff6ad

Please sign in to comment.