Skip to content

Commit

Permalink
Compute input type and weight on InputPair
Browse files Browse the repository at this point in the history
bitcoin::AddressType can act as a substitute for InputType, and we can
use that to compute expected input weights. InputPair contains all the
context necessary to derive those properties, so input_type.rs is
obsolete.
  • Loading branch information
spacebear21 committed Oct 3, 2024
1 parent 2aab905 commit 84f121e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 55 deletions.
50 changes: 49 additions & 1 deletion payjoin/src/psbt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
use std::collections::BTreeMap;
use std::fmt;

use bitcoin::blockdata::script::Instruction;
use bitcoin::psbt::Psbt;
use bitcoin::{bip32, psbt, TxIn, TxOut};
use bitcoin::transaction::InputWeightPrediction;
use bitcoin::{bip32, psbt, Address, AddressType, Network, Script, TxIn, TxOut, Weight};

#[derive(Debug)]
pub(crate) enum InconsistentPsbt {
Expand Down Expand Up @@ -107,6 +109,14 @@ impl PsbtExt for Psbt {
}
}

/// Gets redeemScript from the script_sig following BIP16 rules regarding P2SH spending.
fn redeem_script(script_sig: &Script) -> Option<&Script> {
match script_sig.instructions().last()?.ok()? {
Instruction::PushBytes(bytes) => Some(Script::from_bytes(bytes.as_bytes())),
Instruction::Op(_) => None,
}
}

pub(crate) struct InputPair<'a> {
pub txin: &'a TxIn,
pub psbtin: &'a psbt::Input,
Expand Down Expand Up @@ -180,6 +190,44 @@ impl<'a> InputPair<'a> {
(Some(_), Some(_)) => Err(PsbtInputError::UnequalTxid),
}
}

pub fn address_type(&self) -> AddressType {
let txo = self.previous_txout().expect("PrevTxoutError");
// HACK: Network doesn't matter for our use case of only getting the address type
// but is required in the `from_script` interface. Hardcoded to mainnet.
Address::from_script(&txo.script_pubkey, Network::Bitcoin)
.expect("Unrecognized script")
.address_type()
.expect("UnknownAddressType")
}

pub fn expected_input_weight(&self) -> Weight {
use bitcoin::AddressType::*;

// Get the input weight prediction corresponding to spending an output of this address type
let iwp = match self.address_type() {
P2pkh => InputWeightPrediction::P2PKH_COMPRESSED_MAX,
P2sh =>
match self.psbtin.final_script_sig.as_ref().and_then(|s| redeem_script(s.as_ref()))
{
Some(script) if script.is_witness_program() && script.is_p2wpkh() =>
// input script: 0x160014{20-byte-key-hash} = 23 bytes
// witness: <signature> <pubkey> = 72, 33 bytes
// https://github.com/bitcoin/bips/blob/master/bip-0141.mediawiki#p2wpkh-nested-in-bip16-p2sh
InputWeightPrediction::new(23, &[72, 33]),
Some(_) => unimplemented!(),
None => panic!("Input not finalized!"),
},
P2wpkh => InputWeightPrediction::P2WPKH_MAX,
P2wsh => unimplemented!(),
P2tr => InputWeightPrediction::P2TR_KEY_NON_DEFAULT_SIGHASH,
_ => panic!("Unknown address type!"),
};

// Lengths of txid, index and sequence: (32, 4, 4).
let input_weight = iwp.weight() + Weight::from_non_witness_data_size(32 + 4 + 4);
input_weight
}
}

#[derive(Debug)]
Expand Down
2 changes: 1 addition & 1 deletion payjoin/src/receive/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub(crate) enum InternalRequestError {
/// The sender is trying to spend the receiver input
InputOwned(bitcoin::ScriptBuf),
/// The original psbt has mixed input address types that could harm privacy
MixedInputScripts(crate::input_type::InputType, crate::input_type::InputType),
MixedInputScripts(bitcoin::AddressType, bitcoin::AddressType),
/// Unrecognized input type
InputType(crate::input_type::InputTypeError),
/// Original PSBT input has been seen before. Only automatic receivers, aka "interactive" in the spec
Expand Down
16 changes: 3 additions & 13 deletions payjoin/src/receive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ use error::{
};
use optional_parameters::Params;

use crate::input_type::InputType;
use crate::psbt::PsbtExt;

pub trait Headers {
Expand Down Expand Up @@ -229,14 +228,8 @@ impl MaybeMixedInputScripts {
let input_scripts = self
.psbt
.input_pairs()
.scan(&mut err, |err, input| match input.previous_txout() {
Ok(txout) => match InputType::from_spent_input(txout, input.psbtin) {
Ok(input_script) => Some(input_script),
Err(e) => {
**err = Err(RequestError::from(InternalRequestError::InputType(e)));
None
}
},
.scan(&mut err, |err, input| match Ok(input.address_type()) {
Ok(address_type) => Some(address_type),
Err(e) => {
**err = Err(RequestError::from(InternalRequestError::PrevTxOut(e)));
None
Expand Down Expand Up @@ -755,12 +748,9 @@ impl ProvisionalProposal {
.next()
.ok_or(InternalRequestError::OriginalPsbtNotBroadcastable)?;
// Calculate the additional weight contribution
let txo = input_pair.previous_txout().map_err(InternalRequestError::PrevTxOut)?;
let input_type = InputType::from_spent_input(txo, &self.payjoin_psbt.inputs[0])
.map_err(InternalRequestError::InputType)?;
let input_count = self.payjoin_psbt.inputs.len() - self.original_psbt.inputs.len();
log::trace!("input_count : {}", input_count);
let weight_per_input = input_type.expected_input_weight();
let weight_per_input = input_pair.expected_input_weight();
log::trace!("weight_per_input : {}", weight_per_input);
let contribution_weight = weight_per_input * input_count as u64;
log::trace!("contribution_weight: {}", contribution_weight);
Expand Down
6 changes: 3 additions & 3 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt::{self, Display};

use bitcoin::locktime::absolute::LockTime;
use bitcoin::transaction::Version;
use bitcoin::Sequence;
use bitcoin::{AddressType, Sequence};

use crate::input_type::{InputType, InputTypeError};

Expand Down Expand Up @@ -43,8 +43,8 @@ pub(crate) enum InternalValidationError {
ReceiverTxinMissingUtxoInfo,
MixedSequence,
MixedInputTypes {
proposed: InputType,
original: InputType,
proposed: AddressType,
original: AddressType,
},
MissingOrShuffledInputs,
TxOutContainsKeyPaths,
Expand Down
58 changes: 21 additions & 37 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ use bitcoin::psbt::Psbt;
use bitcoin::secp256k1::rand;
#[cfg(feature = "v2")]
use bitcoin::secp256k1::PublicKey;
use bitcoin::{Amount, FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight};
use bitcoin::{AddressType, Amount, FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight};
pub use error::{CreateRequestError, ResponseError, ValidationError};
pub(crate) use error::{InternalCreateRequestError, InternalValidationError};
#[cfg(feature = "v2")]
use serde::{Deserialize, Serialize};
use url::Url;

use crate::input_type::InputType;
use crate::psbt::PsbtExt;
use crate::psbt::{InputPair, PsbtExt};
use crate::request::Request;
#[cfg(feature = "v2")]
use crate::v2::{HpkePublicKey, HpkeSecretKey};
Expand Down Expand Up @@ -129,23 +129,18 @@ impl<'a> RequestBuilder<'a> {
.find(|(_, txo)| payout_scripts.all(|script| script != txo.script_pubkey))
.map(|(i, txo)| (i, txo.value))
{
let input_types = self
.psbt
.input_pairs()
.map(|input| {
let txo =
input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?;
InputType::from_spent_input(txo, input.psbtin)
.map_err(InternalCreateRequestError::InputType)
})
.collect::<Result<Vec<InputType>, InternalCreateRequestError>>()?;

let first_type = input_types.first().ok_or(InternalCreateRequestError::NoInputs)?;
let input_pairs = self.psbt.input_pairs().collect::<Vec<InputPair>>();

let first_input_pair =
input_pairs.first().ok_or(InternalCreateRequestError::NoInputs)?;
// use cheapest default if mixed input types
let mut input_vsize = InputType::Taproot.expected_input_weight();
// Check if all inputs are the same type
if input_types.iter().all(|input_type| input_type == first_type) {
input_vsize = first_type.expected_input_weight();
if input_pairs
.iter()
.all(|input_pair| input_pair.address_type() == first_input_pair.address_type())
{
input_vsize = first_input_pair.expected_input_weight();
}

let recommended_additional_fee = min_fee_rate * input_vsize;
Expand Down Expand Up @@ -232,9 +227,7 @@ impl<'a> RequestBuilder<'a> {
let zeroth_input = psbt.input_pairs().next().ok_or(InternalCreateRequestError::NoInputs)?;

let sequence = zeroth_input.txin.sequence;
let txout = zeroth_input.previous_txout().map_err(InternalCreateRequestError::PrevTxOut)?;
let input_type = InputType::from_spent_input(txout, zeroth_input.psbtin)
.map_err(InternalCreateRequestError::InputType)?;
let input_type = zeroth_input.address_type().to_string();

Ok(RequestContext {
psbt,
Expand All @@ -259,7 +252,7 @@ pub struct RequestContext {
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
input_type: InputType,
input_type: String,
sequence: Sequence,
payee: ScriptBuf,
#[cfg(feature = "v2")]
Expand All @@ -285,7 +278,7 @@ impl RequestContext {
disable_output_substitution: self.disable_output_substitution,
fee_contribution: self.fee_contribution,
payee: self.payee.clone(),
input_type: self.input_type,
input_type: AddressType::from_str(&self.input_type).expect("Unknown address type"),
sequence: self.sequence,
min_fee_rate: self.min_fee_rate,
},
Expand Down Expand Up @@ -355,7 +348,8 @@ impl RequestContext {
disable_output_substitution: self.disable_output_substitution,
fee_contribution: self.fee_contribution,
payee: self.payee.clone(),
input_type: self.input_type,
input_type: AddressType::from_str(&self.input_type)
.expect("Unknown address type"),
sequence: self.sequence,
min_fee_rate: self.min_fee_rate,
},
Expand Down Expand Up @@ -399,7 +393,7 @@ pub struct ContextV1 {
disable_output_substitution: bool,
fee_contribution: Option<(bitcoin::Amount, usize)>,
min_fee_rate: FeeRate,
input_type: InputType,
input_type: AddressType,
sequence: Sequence,
payee: ScriptBuf,
}
Expand Down Expand Up @@ -503,7 +497,7 @@ impl ContextV1 {
ensure!(
contributed_fee
<= original_fee_rate
* self.input_type.expected_input_weight()
* self.original_psbt.input_pairs().next().unwrap().expected_input_weight()
* (proposal.inputs.len() - self.original_psbt.inputs.len()) as u64,
FeeContributionPaysOutputSizeIncrease
);
Expand Down Expand Up @@ -575,14 +569,7 @@ impl ContextV1 {
ReceiverTxinMissingUtxoInfo
);
ensure!(proposed.txin.sequence == self.sequence, MixedSequence);
let txout = proposed
.previous_txout()
.map_err(InternalValidationError::InvalidProposedInput)?;
check_eq!(
InputType::from_spent_input(txout, proposed.psbtin)?,
self.input_type,
MixedInputTypes
);
check_eq!(proposed.address_type(), self.input_type, MixedInputTypes);
}
}
}
Expand Down Expand Up @@ -860,7 +847,7 @@ mod test {
fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)),
min_fee_rate: FeeRate::ZERO,
payee,
input_type: InputType::SegWitV0 { ty: SegWitV0Type::Pubkey, nested: true },
input_type: bitcoin::AddressType::P2sh,
sequence,
};
ctx
Expand Down Expand Up @@ -916,10 +903,7 @@ mod test {
disable_output_substitution: false,
fee_contribution: None,
min_fee_rate: FeeRate::ZERO,
input_type: InputType::SegWitV0 {
ty: crate::input_type::SegWitV0Type::Pubkey,
nested: true,
},
input_type: bitcoin::AddressType::P2sh.to_string(),
sequence: Sequence::MAX,
payee: ScriptBuf::from(vec![0x00]),
e: HpkeSecretKey(
Expand Down

0 comments on commit 84f121e

Please sign in to comment.