diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 3dce6a98..9e4d23ea 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -137,14 +137,7 @@ impl App { .extract_v2_req() .map_err(|e| anyhow!("v2 req extraction failed {}", e))?; println!("Got a request from the sender. Responding with a Payjoin proposal."); - let http = http_agent()?; - let res = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let res = post_request(req).await?; payjoin_proposal .process_res(res.bytes().await?.to_vec(), ohttp_ctx) .map_err(|e| anyhow!("Failed to deserialize response {}", e))?; @@ -197,31 +190,17 @@ impl App { } async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { - let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; - println!("Posting Original PSBT Payload request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - println!("Sent fallback transaction"); - match ctx { - payjoin::send::Context::V2(ctx) => { + match req_ctx.extract_v2(self.config.ohttp_relay.clone()) { + Ok((req, ctx)) => { + println!("Posting Original PSBT Payload request..."); + let response = post_request(req).await?; + println!("Sent fallback transaction"); let v2_ctx = Arc::new( ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, ); loop { let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let response = post_request(req).await?; match v2_ctx.process_response( &mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx, @@ -239,8 +218,12 @@ impl App { } } } - payjoin::send::Context::V1(ctx) => { - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Err(_) => { + let (req, v1_ctx) = req_ctx.extract_v1()?; + println!("Posting Original PSBT Payload request..."); + let response = post_request(req).await?; + println!("Sent fallback transaction"); + match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { Ok(psbt) => Ok(psbt), Err(re) => { println!("{}", re); @@ -259,15 +242,7 @@ impl App { loop { let (req, context) = session.extract_req()?; println!("Polling receive request..."); - let http = http_agent()?; - let ohttp_response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - + let ohttp_response = post_request(req).await?; let proposal = session .process_res(ohttp_response.bytes().await?.to_vec().as_slice(), context) .map_err(|_| anyhow!("GET fallback failed"))?; @@ -407,6 +382,16 @@ async fn handle_interrupt(tx: watch::Sender<()>) { let _ = tx.send(()); } +async fn post_request(req: payjoin::Request) -> Result { + let http = http_agent()?; + http.post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err) +} + fn map_reqwest_err(e: reqwest::Error) -> anyhow::Error { match e.status() { Some(status_code) => anyhow!("HTTP request failed: {} {}", status_code, e), diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 28e8b148..8400d90b 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -268,46 +268,22 @@ impl Sender { )) } - /// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version. - /// - /// In order to support polling, this may need to be called many times to be encrypted with - /// new unique nonces to make independent OHTTP requests. + /// Extract serialized Request and Context from a Payjoin Proposal. /// - /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver + /// This method requires the `rs` pubkey to be extracted from the endpoint + /// and has no fallback to v1. #[cfg(feature = "v2")] - pub fn extract_highest_version( - &mut self, + pub fn extract_v2( + &self, ohttp_relay: Url, - ) -> Result<(Request, Context), CreateRequestError> { + ) -> Result<(Request, V2PostContext), CreateRequestError> { use crate::uri::UrlExt; - if let Some(expiry) = self.endpoint.exp() { if std::time::SystemTime::now() > expiry { return Err(InternalCreateRequestError::Expired(expiry).into()); } } - - match self.extract_rs_pubkey() { - Ok(rs) => self.extract_v2(ohttp_relay, rs), - Err(e) => { - log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); - let (req, context_v1) = self.extract_v1()?; - Ok((req, Context::V1(context_v1))) - } - } - } - - /// Extract serialized Request and Context from a Payjoin Proposal. - /// - /// This method requires the `rs` pubkey to be extracted from the endpoint - /// and has no fallback to v1. - #[cfg(feature = "v2")] - fn extract_v2( - &mut self, - ohttp_relay: Url, - rs: HpkePublicKey, - ) -> Result<(Request, Context), CreateRequestError> { - use crate::uri::UrlExt; + let rs = self.extract_rs_pubkey()?; let url = self.endpoint.clone(); let body = serialize_v2_body( &self.psbt, @@ -329,7 +305,7 @@ impl Sender { log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), - Context::V2(V2PostContext { + V2PostContext { endpoint: self.endpoint.clone(), psbt_ctx: PsbtContext { original_psbt: self.psbt.clone(), @@ -341,7 +317,7 @@ impl Sender { }, hpke_ctx, ohttp_ctx, - }), + }, )) } @@ -366,12 +342,6 @@ impl Sender { pub fn endpoint(&self) -> &Url { &self.endpoint } } -pub enum Context { - V1(V1Context), - #[cfg(feature = "v2")] - V2(V2PostContext), -} - #[derive(Debug, Clone)] pub struct V1Context { psbt_context: PsbtContext, diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index e4701792..82eb4339 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -180,7 +180,6 @@ mod integration { use bitcoin::Address; use http::StatusCode; use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; - use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -285,9 +284,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_highest_version(directory.to_owned()) { + match expired_req_ctx.extract_v2(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -355,14 +354,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; - let send_ctx = match send_ctx { - Context::V2(ctx) => ctx, - _ => panic!("V2 context expected"), - }; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -521,10 +516,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, post_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -534,11 +529,8 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let get_ctx = match post_ctx { - Context::V2(ctx) => - ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, - _ => panic!("V2 context expected"), - }; + let get_ctx = + post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; let (Request { url, body, content_type, .. }, ohttp_ctx) = get_ctx.extract_req(directory.to_owned())?; let response = agent @@ -622,9 +614,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -636,10 +628,6 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let ctx = match ctx { - Context::V1(ctx) => ctx, - _ => panic!("V1 context expected"), - }; let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?;