From 1001583517c9a106dddc9e08774288e8dcb52d60 Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 12 Dec 2023 21:06:08 -0500 Subject: [PATCH] Persist send sessions for async pj --- .gitignore | 2 +- payjoin-cli/src/app.rs | 64 ++++++++++++- payjoin-cli/src/main.rs | 18 ++-- payjoin/src/input_type.rs | 56 +++++++++++ payjoin/src/send/mod.rs | 189 +++++++++++++++++++++++++++++++++++--- payjoin/src/v2.rs | 7 +- 6 files changed, 305 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index 72648211..9fb534de 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ target *config.toml *seen_inputs.json -*session_store.json +*_store.json diff --git a/payjoin-cli/src/app.rs b/payjoin-cli/src/app.rs index 20c366f0..5a3f662c 100644 --- a/payjoin-cli/src/app.rs +++ b/payjoin-cli/src/app.rs @@ -34,6 +34,7 @@ const LOCAL_CERT_FILE: &str = "localhost.der"; pub(crate) struct App { config: AppConfig, receive_store: Arc>, + send_store: Arc>, seen_inputs: Arc>, } @@ -41,7 +42,8 @@ impl App { pub fn new(config: AppConfig) -> Result { let seen_inputs = Arc::new(Mutex::new(SeenInputs::new()?)); let receive_store = Arc::new(Mutex::new(ReceiveStore::new()?)); - Ok(Self { config, receive_store, seen_inputs }) + let send_store = Arc::new(Mutex::new(SendStore::new()?)); + Ok(Self { config, receive_store, send_store, seen_inputs }) } pub fn bitcoind(&self) -> Result { @@ -62,8 +64,18 @@ impl App { } #[cfg(feature = "v2")] - pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()> { - let req_ctx = self.create_pj_request(bip21, fee_rate)?; + pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32, is_retry: bool) -> Result<()> { + let mut session = self.send_store.lock().expect("mutex lock failed"); + let req_ctx = if is_retry { + log::debug!("Resuming session"); + // Get a reference to RequestContext + session.req_ctx.as_ref().expect("RequestContext is missing") + } else { + let req_ctx = self.create_pj_request(bip21, fee_rate)?; + session.write(req_ctx)?; + log::debug!("Writing req_ctx"); + session.req_ctx.as_ref().expect("RequestContext is missing") + }; log::debug!("Awaiting response"); let res = self.long_poll_post(req_ctx).await?; self.process_pj_response(res)?; @@ -173,7 +185,7 @@ impl App { } #[cfg(feature = "v2")] - async fn long_poll_post(&self, req_ctx: payjoin::send::RequestContext<'_>) -> Result { + async fn long_poll_post(&self, req_ctx: &payjoin::send::RequestContext) -> Result { loop { let (req, ctx, ohttp) = req_ctx.extract_v2(&self.config.ohttp_proxy)?; println!("Sending fallback request to {}", &req.url); @@ -221,7 +233,7 @@ impl App { } } - fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result> { + fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result { let uri = payjoin::Uri::try_from(bip21) .map_err(|e| anyhow!("Failed to create URI from BIP21: {}", e))?; @@ -288,6 +300,7 @@ impl App { .bitcoind()? .send_raw_transaction(&tx) .with_context(|| "Failed to send raw transaction")?; + self.send_store.lock().expect("mutex lock failed").clear()?; println!("Payjoin sent: {}", txid); Ok(txid) } @@ -631,12 +644,53 @@ impl App { } } +#[cfg(feature = "v2")] +struct SendStore { + req_ctx: Option, + file: std::fs::File, +} + +impl SendStore { + fn new() -> Result { + let mut file = + OpenOptions::new().write(true).read(true).create(true).open("send_store.json")?; + let session = match serde_json::from_reader(&mut file) { + Ok(session) => Some(session), + Err(e) => { + log::debug!("error reading send session store: {}", e); + None + } + }; + + Ok(Self { req_ctx: session, file }) + } + + fn write( + &mut self, + session: payjoin::send::RequestContext, + ) -> Result<&mut payjoin::send::RequestContext> { + use std::io::Write; + + let session = self.req_ctx.insert(session); + let serialized = serde_json::to_string(session)?; + self.file.write_all(serialized.as_bytes())?; + Ok(session) + } + + fn clear(&mut self) -> Result<()> { + let file = OpenOptions::new().write(true).open("send_store.json")?; + file.set_len(0)?; + Ok(()) + } +} + #[cfg(feature = "v2")] struct ReceiveStore { session: Option, file: std::fs::File, } +#[cfg(feature = "v2")] impl ReceiveStore { fn new() -> Result { let mut file = diff --git a/payjoin-cli/src/main.rs b/payjoin-cli/src/main.rs index 46db3996..0cb72fba 100644 --- a/payjoin-cli/src/main.rs +++ b/payjoin-cli/src/main.rs @@ -17,14 +17,18 @@ async fn main() -> Result<()> { let bip21 = sub_matches.get_one::("BIP21").context("Missing BIP21 argument")?; let fee_rate_sat_per_vb = sub_matches.get_one::("fee_rate").context("Missing --fee-rate argument")?; + #[cfg(feature = "v2")] + let is_retry = matches.get_one::("retry").context("Could not read --retry")?; + #[cfg(feature = "v2")] + app.send_payjoin(bip21, fee_rate_sat_per_vb, *is_retry).await?; + #[cfg(not(feature = "v2"))] app.send_payjoin(bip21, fee_rate_sat_per_vb).await?; } Some(("receive", sub_matches)) => { let amount = sub_matches.get_one::("AMOUNT").context("Missing AMOUNT argument")?; #[cfg(feature = "v2")] - let is_retry = - sub_matches.get_one::("retry").context("Could not read --retry")?; + let is_retry = matches.get_one::("retry").context("Could not read --retry")?; #[cfg(feature = "v2")] app.receive_payjoin(amount, *is_retry).await?; #[cfg(not(feature = "v2"))] @@ -64,6 +68,11 @@ fn cli() -> ArgMatches { .arg(Arg::new("ohttp_proxy") .long("ohttp-proxy") .help("The ohttp proxy url")) + .arg(Arg::new("retry") + .long("retry") + .short('e') + .action(clap::ArgAction::SetTrue) + .help("Retry the asynchronous payjoin request if it did not yet complete")) .subcommand( Command::new("send") .arg_required_else_help(true) @@ -91,11 +100,6 @@ fn cli() -> ArgMatches { .short('e') .takes_value(true) .help("The `pj=` endpoint to receive the payjoin request")) - .arg(Arg::new("retry") - .long("retry") - .short('r') - .action(clap::ArgAction::SetTrue) - .help("Retry the asynchronous payjoin request if it did not yet complete")) .arg(Arg::new("sub_only") .long("sub-only") .short('s') diff --git a/payjoin/src/input_type.rs b/payjoin/src/input_type.rs index e85b5b8a..8601c05c 100644 --- a/payjoin/src/input_type.rs +++ b/payjoin/src/input_type.rs @@ -22,6 +22,62 @@ pub(crate) enum InputType { Taproot, } +#[cfg(feature = "v2")] +impl serde::Serialize for InputType { + fn serialize(&self, serializer: S) -> Result { + use InputType::*; + + match self { + P2Pk => serializer.serialize_str("P2PK"), + P2Pkh => serializer.serialize_str("P2PKH"), + P2Sh => serializer.serialize_str("P2SH"), + SegWitV0 { ty, nested } => + serializer.serialize_str(&format!("SegWitV0: type={}, nested={}", ty, nested)), + Taproot => serializer.serialize_str("Taproot"), + } + } +} + +impl<'de> serde::Deserialize<'de> for InputType { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use InputType::*; + + let s = String::deserialize(deserializer)?; + if s.starts_with("SegWitV0: ") { + let rest = &s["SegWitV0: ".len()..]; + let parts: Vec<&str> = rest.split(", ").collect(); + if parts.len() != 2 { + return Err(serde::de::Error::custom("invalid format for SegWitV0")); + } + log::debug!("parts: {:?}", parts); + let ty = match parts[0].strip_prefix("type=") { + Some("pubkey") => SegWitV0Type::Pubkey, + Some("script") => SegWitV0Type::Script, + _ => return Err(serde::de::Error::custom("invalid SegWitV0 type")), + }; + + let nested = match parts[1].strip_prefix("nested=") { + Some("true") => true, + Some("false") => false, + _ => return Err(serde::de::Error::custom("invalid SegWitV0 nested value")), + }; + + Ok(SegWitV0 { ty, nested }) + } else { + match s.as_str() { + "P2PK" => Ok(P2Pk), + "P2PKH" => Ok(P2Pkh), + "P2SH" => Ok(P2Sh), + "Taproot" => Ok(Taproot), + _ => Err(serde::de::Error::custom("invalid type")), + } + } + } +} + impl InputType { pub(crate) fn from_spent_input( txout: &TxOut, diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index f3e7bac3..61511555 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -142,6 +142,7 @@ use bitcoin::psbt::Psbt; use bitcoin::{FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight}; pub use error::{CreateRequestError, ValidationError}; pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; +use serde::ser::SerializeStruct; use url::Url; use crate::input_type::InputType; @@ -211,7 +212,7 @@ impl<'a> RequestBuilder<'a> { pub fn build_recommended( self, min_fee_rate: FeeRate, - ) -> Result, CreateRequestError> { + ) -> Result { // TODO support optional batched payout scripts. This would require a change to // build() which now checks for a single payee. let mut payout_scripts = std::iter::once(self.uri.address.script_pubkey()); @@ -283,7 +284,7 @@ impl<'a> RequestBuilder<'a> { change_index: Option, min_fee_rate: FeeRate, clamp_fee_contribution: bool, - ) -> Result, CreateRequestError> { + ) -> Result { self.fee_contribution = Some((max_fee_contribution, change_index)); self.clamp_fee_contribution = clamp_fee_contribution; self.min_fee_rate = min_fee_rate; @@ -294,7 +295,7 @@ impl<'a> RequestBuilder<'a> { /// /// While it's generally better to offer some contribution some users may wish not to. /// This function disables contribution. - pub fn build_non_incentivizing(mut self) -> Result, CreateRequestError> { + pub fn build_non_incentivizing(mut self) -> Result { // since this is a builder, these should already be cleared // but we'll reset them to be sure self.fee_contribution = None; @@ -303,11 +304,13 @@ impl<'a> RequestBuilder<'a> { self.build() } - fn build(self) -> Result, CreateRequestError> { + fn build(self) -> Result { let mut psbt = self.psbt.validate().map_err(InternalCreateRequestError::InconsistentOriginalPsbt)?; psbt.validate_input_utxos(true) .map_err(InternalCreateRequestError::InvalidOriginalInput)?; + let endpoint = self.uri.extras._endpoint.clone(); + let ohttp_config = self.uri.extras.ohttp_config; let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; let payee = self.uri.address.script_pubkey(); @@ -327,35 +330,50 @@ impl<'a> RequestBuilder<'a> { let txout = zeroth_input.previous_txout().expect("We already checked this above"); let input_type = InputType::from_spent_input(txout, zeroth_input.psbtin).unwrap(); + #[cfg(feature = "v2")] + let e = { + let secp = bitcoin::secp256k1::Secp256k1::new(); + let (e_sec, _) = secp.generate_keypair(&mut rand::rngs::OsRng); + e_sec + }; + Ok(RequestContext { psbt, - uri: self.uri, + endpoint, + #[cfg(feature = "v2")] + ohttp_config, disable_output_substitution, fee_contribution, payee, input_type, sequence, min_fee_rate: self.min_fee_rate, + #[cfg(feature = "v2")] + e, }) } } -pub struct RequestContext<'a> { +pub struct RequestContext { psbt: Psbt, - uri: PjUri<'a>, + endpoint: Url, + #[cfg(feature = "v2")] + ohttp_config: Option, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, input_type: InputType, sequence: Sequence, payee: ScriptBuf, + #[cfg(feature = "v2")] + e: bitcoin::secp256k1::SecretKey, } -impl<'a> RequestContext<'a> { +impl RequestContext { /// Extract serialized V1 Request and Context froma Payjoin Proposal pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> { let url = serialize_url( - self.uri.extras._endpoint.into(), + self.endpoint.into(), self.disable_output_substitution, self.fee_contribution, self.min_fee_rate, @@ -387,7 +405,7 @@ impl<'a> RequestContext<'a> { &self, ohttp_proxy_url: &str, ) -> Result<(Request, ContextV2, ohttp::ClientResponse), CreateRequestError> { - let rs_base64 = crate::v2::subdir(self.uri.extras._endpoint.as_str()).to_string(); + let rs_base64 = crate::v2::subdir(self.endpoint.as_str()).to_string(); log::debug!("rs_base64: {:?}", rs_base64); let b64_config = bitcoin::base64::Config::new(bitcoin::base64::CharacterSet::UrlSafe, false); @@ -395,17 +413,17 @@ impl<'a> RequestContext<'a> { log::debug!("rs: {:?}", rs.len()); let rs = bitcoin::secp256k1::PublicKey::from_slice(&rs).unwrap(); - let url = self.uri.extras._endpoint.clone(); + let url = self.endpoint.clone(); let body = serialize_v2_body( &self.psbt, self.disable_output_substitution, self.fee_contribution, self.min_fee_rate, )?; - let (body, e) = - crate::v2::encrypt_message_a(body, rs).map_err(InternalCreateRequestError::V2)?; + let body = crate::v2::encrypt_message_a(body, self.e, rs) + .map_err(InternalCreateRequestError::V2)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( - &self.uri.extras.ohttp_config.as_ref().unwrap().encode().unwrap(), + &self.ohttp_config.as_ref().unwrap().encode().unwrap(), "POST", url.as_str(), Some(&body), @@ -426,13 +444,154 @@ impl<'a> RequestContext<'a> { sequence: self.sequence, min_fee_rate: self.min_fee_rate, }, - e, + e: self.e, }, ohttp_res, )) } } +#[cfg(feature = "v2")] +impl serde::Serialize for RequestContext { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut state = serializer.serialize_struct("RequestContext", 8)?; + state.serialize_field("psbt", &self.psbt.to_string())?; + state.serialize_field("endpoint", &self.endpoint.as_str())?; + let ohttp_string = self + .ohttp_config + .as_ref() + .map_or("".to_string(), |config| bitcoin::base64::encode(config.encode().unwrap())); + state.serialize_field("ohttp_config", &ohttp_string)?; + state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?; + state.serialize_field( + "fee_contribution", + &self.fee_contribution.as_ref().map(|(amount, index)| (amount.to_sat(), *index)), + )?; + state.serialize_field("min_fee_rate", &self.min_fee_rate)?; + state.serialize_field("input_type", &self.input_type)?; + state.serialize_field("sequence", &self.sequence)?; + state.serialize_field("payee", &self.payee)?; + state.serialize_field("e", &self.e.secret_bytes())?; + state.end() + } +} + +use serde::de::{self, MapAccess, Visitor}; +use serde::{Deserialize, Deserializer}; + +impl<'de> Deserialize<'de> for RequestContext { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct RequestContextVisitor; + + const FIELDS: &'static [&'static str] = &[ + "psbt", + "endpoint", + "ohttp_config", + "disable_output_substitution", + "fee_contribution", + "min_fee_rate", + "input_type", + "sequence", + "payee", + "e", + ]; + + impl<'de> Visitor<'de> for RequestContextVisitor { + type Value = RequestContext; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct RequestContext") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut psbt = None; + let mut endpoint = None; + let mut ohttp_config = None; + let mut disable_output_substitution = None; + let mut fee_contribution = None; + let mut min_fee_rate = None; + let mut input_type = None; + let mut sequence = None; + let mut payee = None; + let mut e = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "psbt" => { + let buf: String = map.next_value::()?; + psbt = Some(Psbt::from_str(&buf).map_err(de::Error::custom)?); + } + "endpoint" => + endpoint = Some( + url::Url::from_str(&map.next_value::()?) + .map_err(de::Error::custom)?, + ), + "ohttp_config" => { + let ohttp_base64: String = map.next_value()?; + ohttp_config = if ohttp_base64.is_empty() { + None + } else { + Some( + ohttp::KeyConfig::decode( + bitcoin::base64::decode(&ohttp_base64) + .map_err(de::Error::custom)? + .as_slice(), + ) + .map_err(de::Error::custom)?, + ) + }; + } + "disable_output_substitution" => + disable_output_substitution = Some(map.next_value()?), + "fee_contribution" => { + let fc: Option<(u64, usize)> = map.next_value()?; + fee_contribution = fc + .map(|(amount, index)| (bitcoin::Amount::from_sat(amount), index)); + } + "min_fee_rate" => min_fee_rate = Some(map.next_value()?), + "input_type" => input_type = Some(map.next_value()?), + "sequence" => sequence = Some(map.next_value()?), + "payee" => payee = Some(map.next_value()?), + "e" => { + let secret_bytes: Vec = map.next_value()?; + e = Some( + bitcoin::secp256k1::SecretKey::from_slice(&secret_bytes) + .map_err(de::Error::custom)?, + ); + } + _ => return Err(de::Error::unknown_field(key.as_str(), FIELDS)), + } + } + + Ok(RequestContext { + psbt: psbt.ok_or_else(|| de::Error::missing_field("psbt"))?, + endpoint: endpoint.ok_or_else(|| de::Error::missing_field("endpoint"))?, + ohttp_config, + disable_output_substitution: disable_output_substitution + .ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?, + fee_contribution, + min_fee_rate: min_fee_rate + .ok_or_else(|| de::Error::missing_field("min_fee_rate"))?, + input_type: input_type.ok_or_else(|| de::Error::missing_field("input_type"))?, + sequence: sequence.ok_or_else(|| de::Error::missing_field("sequence"))?, + payee: payee.ok_or_else(|| de::Error::missing_field("payee"))?, + e: e.ok_or_else(|| de::Error::missing_field("e"))?, + }) + } + } + + deserializer.deserialize_struct("RequestContext", FIELDS, RequestContextVisitor) + } +} /// Represents data that needs to be transmitted to the receiver. /// /// You need to send this request over HTTP(S) to the receiver. diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index 5f4fe54e..0401750c 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -34,10 +34,11 @@ use chacha20poly1305::{AeadCore, ChaCha20Poly1305, Nonce}; /// <- Receiver E, EE(payload), payload protected by knowledge of sender & receiver key pub fn encrypt_message_a( mut raw_msg: Vec, + e_sec: SecretKey, s: PublicKey, -) -> Result<(Vec, SecretKey), Error> { +) -> Result, Error> { let secp = Secp256k1::new(); - let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); + let e_pub = e_sec.public_key(&secp); let es = SharedSecret::new(&s, &e_sec); let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) .map_err(|_| InternalError::InvalidKeyLength)?; @@ -49,7 +50,7 @@ pub fn encrypt_message_a( let mut message_a = e_pub.serialize().to_vec(); message_a.extend(&nonce[..]); message_a.extend(&c_t[..]); - Ok((message_a, e_sec)) + Ok(message_a) } pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> Result<(Vec, PublicKey), Error> {