diff --git a/payjoin/src/psbt.rs b/payjoin/src/psbt.rs index c8fa043e..371e4e6d 100644 --- a/payjoin/src/psbt.rs +++ b/payjoin/src/psbt.rs @@ -3,8 +3,10 @@ use std::collections::BTreeMap; use std::fmt; +use bitcoin::blockdata::script::Instruction; use bitcoin::psbt::Psbt; -use bitcoin::{bip32, psbt, TxIn, TxOut}; +use bitcoin::transaction::InputWeightPrediction; +use bitcoin::{bip32, psbt, Address, AddressType, Network, Script, TxIn, TxOut, Weight}; #[derive(Debug)] pub(crate) enum InconsistentPsbt { @@ -107,6 +109,14 @@ impl PsbtExt for Psbt { } } +/// Gets redeemScript from the script_sig following BIP16 rules regarding P2SH spending. +fn redeem_script(script_sig: &Script) -> Option<&Script> { + match script_sig.instructions().last()?.ok()? { + Instruction::PushBytes(bytes) => Some(Script::from_bytes(bytes.as_bytes())), + Instruction::Op(_) => None, + } +} + pub(crate) struct InputPair<'a> { pub txin: &'a TxIn, pub psbtin: &'a psbt::Input, @@ -180,6 +190,44 @@ impl<'a> InputPair<'a> { (Some(_), Some(_)) => Err(PsbtInputError::UnequalTxid), } } + + pub fn address_type(&self) -> AddressType { + let txo = self.previous_txout().expect("PrevTxoutError"); + // HACK: Network doesn't matter for our use case of only getting the address type + // but is required in the `from_script` interface. Hardcoded to mainnet. + Address::from_script(&txo.script_pubkey, Network::Bitcoin) + .expect("Unrecognized script") + .address_type() + .expect("UnknownAddressType") + } + + pub fn expected_input_weight(&self) -> Weight { + use bitcoin::AddressType::*; + + // Get the input weight prediction corresponding to spending an output of this address type + let iwp = match self.address_type() { + P2pkh => InputWeightPrediction::P2PKH_COMPRESSED_MAX, + P2sh => + match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref())) + { + Some(script) if script.is_witness_program() && script.is_p2wpkh() => + // input script: 0x160014{20-byte-key-hash} = 23 bytes + // witness: = 72, 33 bytes + // https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh + InputWeightPrediction::new(23, &[72, 33]), + Some(_) => unimplemented!(), + None => panic!("Input not finalized!"), + }, + P2wpkh => InputWeightPrediction::P2WPKH_MAX, + P2wsh => unimplemented!(), + P2tr => InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH, + _ => panic!("Unknown address type!"), + }; + + // Lengths of txid, index and sequence: (32, 4, 4). + let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4); + input_weight + } } #[derive(Debug)] diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 2056bcf8..a897fe3f 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -73,7 +73,7 @@ pub(crate) enum InternalRequestError { /// The sender is trying to spend the receiver input InputOwned(bitcoin::ScriptBuf), /// The original psbt has mixed input address types that could harm privacy - MixedInputScripts(crate::input_type::InputType, crate::input_type::InputType), + MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType), /// Unrecognized input type InputType(crate::input_type::InputTypeError), /// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index 0f668040..4bc5682b 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -46,7 +46,6 @@ use error::{ }; use optional_parameters::Params; -use crate::input_type::InputType; use crate::psbt::PsbtExt; pub trait Headers { @@ -229,14 +228,8 @@ impl MaybeMixedInputScripts { let input_scripts = self .psbt .input_pairs() - .scan(&mut err, |err, input| match input.previous_txout() { - Ok(txout) => match InputType::from_spent_input(txout, input.psbtin) { - Ok(input_script) => Some(input_script), - Err(e) => { - **err = Err(RequestError::from(InternalRequestError::InputType(e))); - None - } - }, + .scan(&mut err, |err, input| match Ok(input.address_type()) { + Ok(address_type) => Some(address_type), Err(e) => { **err = Err(RequestError::from(InternalRequestError::PrevTxOut(e))); None @@ -755,12 +748,9 @@ impl ProvisionalProposal { .next() .ok_or(InternalRequestError::OriginalPsbtNotBroadcastable)?; // Calculate the additional weight contribution - let txo = input_pair.previous_txout().map_err(InternalRequestError::PrevTxOut)?; - let input_type = InputType::from_spent_input(txo, &self.payjoin_psbt.inputs[0]) - .map_err(InternalRequestError::InputType)?; let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len(); log::trace!("input_count : {}", input_count); - let weight_per_input = input_type.expected_input_weight(); + let weight_per_input = input_pair.expected_input_weight(); log::trace!("weight_per_input : {}", weight_per_input); let contribution_weight = weight_per_input * input_count as u64; log::trace!("contribution_weight: {}", contribution_weight); diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 3be59131..f9e86a9f 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -2,7 +2,7 @@ use std::fmt::{self, Display}; use bitcoin::locktime::absolute::LockTime; use bitcoin::transaction::Version; -use bitcoin::Sequence; +use bitcoin::{AddressType, Sequence}; use crate::input_type::{InputType, InputTypeError}; @@ -43,8 +43,8 @@ pub(crate) enum InternalValidationError { ReceiverTxinMissingUtxoInfo, MixedSequence, MixedInputTypes { - proposed: InputType, - original: InputType, + proposed: AddressType, + original: AddressType, }, MissingOrShuffledInputs, TxOutContainsKeyPaths, diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index c0c60ffe..545b9da4 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -31,7 +31,7 @@ use bitcoin::psbt::Psbt; use bitcoin::secp256k1::rand; #[cfg(feature = "v2")] use bitcoin::secp256k1::PublicKey; -use bitcoin::{Amount, FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight}; +use bitcoin::{AddressType, Amount, FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; #[cfg(feature = "v2")] @@ -39,7 +39,7 @@ use serde::{Deserialize, Serialize}; use url::Url; use crate::input_type::InputType; -use crate::psbt::PsbtExt; +use crate::psbt::{InputPair, PsbtExt}; use crate::request::Request; #[cfg(feature = "v2")] use crate::v2::{HpkePublicKey, HpkeSecretKey}; @@ -129,23 +129,18 @@ impl<'a> RequestBuilder<'a> { .find(|(_, txo)| payout_scripts.all(|script| script != txo.script_pubkey)) .map(|(i, txo)| (i, txo.value)) { - let input_types = self - .psbt - .input_pairs() - .map(|input| { - let txo = - input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?; - InputType::from_spent_input(txo, input.psbtin) - .map_err(InternalCreateRequestError::InputType) - }) - .collect::, InternalCreateRequestError>>()?; - - let first_type = input_types.first().ok_or(InternalCreateRequestError::NoInputs)?; + let input_pairs = self.psbt.input_pairs().collect::>(); + + let first_input_pair = + input_pairs.first().ok_or(InternalCreateRequestError::NoInputs)?; // use cheapest default if mixed input types let mut input_vsize = InputType::Taproot.expected_input_weight(); // Check if all inputs are the same type - if input_types.iter().all(|input_type| input_type == first_type) { - input_vsize = first_type.expected_input_weight(); + if input_pairs + .iter() + .all(|input_pair| input_pair.address_type() == first_input_pair.address_type()) + { + input_vsize = first_input_pair.expected_input_weight(); } let recommended_additional_fee = min_fee_rate * input_vsize; @@ -232,9 +227,7 @@ impl<'a> RequestBuilder<'a> { let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?; let sequence = zeroth_input.txin.sequence; - let txout = zeroth_input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?; - let input_type = InputType::from_spent_input(txout, zeroth_input.psbtin) - .map_err(InternalCreateRequestError::InputType)?; + let input_type = zeroth_input.address_type().to_string(); Ok(RequestContext { psbt, @@ -259,7 +252,7 @@ pub struct RequestContext { disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, - input_type: InputType, + input_type: String, sequence: Sequence, payee: ScriptBuf, #[cfg(feature = "v2")] @@ -285,7 +278,7 @@ impl RequestContext { disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), - input_type: self.input_type, + input_type: AddressType::from_str(&self.input_type).expect("Unknown address type"), sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, @@ -355,7 +348,8 @@ impl RequestContext { disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), - input_type: self.input_type, + input_type: AddressType::from_str(&self.input_type) + .expect("Unknown address type"), sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, @@ -399,7 +393,7 @@ pub struct ContextV1 { disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, - input_type: InputType, + input_type: AddressType, sequence: Sequence, payee: ScriptBuf, } @@ -503,7 +497,7 @@ impl ContextV1 { ensure!( contributed_fee <= original_fee_rate - * self.input_type.expected_input_weight() + * self.original_psbt.input_pairs().next().unwrap().expected_input_weight() * (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64, FeeContributionPaysOutputSizeIncrease ); @@ -575,14 +569,7 @@ impl ContextV1 { ReceiverTxinMissingUtxoInfo ); ensure!(proposed.txin.sequence == self.sequence, MixedSequence); - let txout = proposed - .previous_txout() - .map_err(InternalValidationError::InvalidProposedInput)?; - check_eq!( - InputType::from_spent_input(txout, proposed.psbtin)?, - self.input_type, - MixedInputTypes - ); + check_eq!(proposed.address_type(), self.input_type, MixedInputTypes); } } } @@ -860,7 +847,7 @@ mod test { fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)), min_fee_rate: FeeRate::ZERO, payee, - input_type: InputType::SegWitV0 { ty: SegWitV0Type::Pubkey, nested: true }, + input_type: bitcoin::AddressType::P2sh, sequence, }; ctx @@ -916,10 +903,7 @@ mod test { disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, - input_type: InputType::SegWitV0 { - ty: crate::input_type::SegWitV0Type::Pubkey, - nested: true, - }, + input_type: bitcoin::AddressType::P2sh.to_string(), sequence: Sequence::MAX, payee: ScriptBuf::from(vec![0x00]), e: HpkeSecretKey(