diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 3dce6a98..5f6dc11d 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -197,19 +197,18 @@ impl App { } async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { - let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; - println!("Posting Original PSBT Payload request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - println!("Sent fallback transaction"); - match ctx { - payjoin::send::Context::V2(ctx) => { + match req_ctx.extract_v2(self.config.ohttp_relay.clone()) { + Ok((req, ctx)) => { + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); let v2_ctx = Arc::new( ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, ); @@ -239,8 +238,19 @@ impl App { } } } - payjoin::send::Context::V1(ctx) => { - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Err(_) => { + let (req, v1_ctx) = req_ctx.extract_v1()?; + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); + match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { Ok(psbt) => Ok(psbt), Err(re) => { println!("{}", re); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 8863ea86..8400d90b 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -268,38 +268,6 @@ impl Sender { )) } - /// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version. - /// - /// In order to support polling, this may need to be called many times to be encrypted with - /// new unique nonces to make independent OHTTP requests. - /// - /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver - #[cfg(feature = "v2")] - pub fn extract_highest_version( - &mut self, - ohttp_relay: Url, - ) -> Result<(Request, Context), CreateRequestError> { - use crate::uri::UrlExt; - - if let Some(expiry) = self.endpoint.exp() { - if std::time::SystemTime::now() > expiry { - return Err(InternalCreateRequestError::Expired(expiry).into()); - } - } - - match self.extract_rs_pubkey() { - Ok(_rs) => { - let (req, context_v2) = self.extract_v2(ohttp_relay)?; - Ok((req, Context::V2(context_v2))) - } - Err(e) => { - log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); - let (req, context_v1) = self.extract_v1()?; - Ok((req, Context::V1(context_v1))) - } - } - } - /// Extract serialized Request and Context from a Payjoin Proposal. /// /// This method requires the `rs` pubkey to be extracted from the endpoint @@ -310,6 +278,11 @@ impl Sender { ohttp_relay: Url, ) -> Result<(Request, V2PostContext), CreateRequestError> { use crate::uri::UrlExt; + if let Some(expiry) = self.endpoint.exp() { + if std::time::SystemTime::now() > expiry { + return Err(InternalCreateRequestError::Expired(expiry).into()); + } + } let rs = self.extract_rs_pubkey()?; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -369,12 +342,6 @@ impl Sender { pub fn endpoint(&self) -> &Url { &self.endpoint } } -pub enum Context { - V1(V1Context), - #[cfg(feature = "v2")] - V2(V2PostContext), -} - #[derive(Debug, Clone)] pub struct V1Context { psbt_context: PsbtContext, diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index e4701792..82eb4339 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -180,7 +180,6 @@ mod integration { use bitcoin::Address; use http::StatusCode; use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; - use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -285,9 +284,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_highest_version(directory.to_owned()) { + match expired_req_ctx.extract_v2(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -355,14 +354,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; - let send_ctx = match send_ctx { - Context::V2(ctx) => ctx, - _ => panic!("V2 context expected"), - }; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -521,10 +516,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, post_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -534,11 +529,8 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let get_ctx = match post_ctx { - Context::V2(ctx) => - ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, - _ => panic!("V2 context expected"), - }; + let get_ctx = + post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; let (Request { url, body, content_type, .. }, ohttp_ctx) = get_ctx.extract_req(directory.to_owned())?; let response = agent @@ -622,9 +614,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -636,10 +628,6 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let ctx = match ctx { - Context::V1(ctx) => ctx, - _ => panic!("V1 context expected"), - }; let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?;