From 1970a59644302a374b5a48f54332a6325a4f4f4c Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 5 Nov 2024 23:31:20 -0500 Subject: [PATCH 1/3] Expose Sender::extract_v2 for bindings The send::Context typestate enum is not simple to bind to in UniFFI and would require abstracting distinct extract_v1 extract_v2 functions in order to cross the FFI boundary. Exposing this method is a simple fix to make such abstraction unnecessary. --- payjoin/src/send/mod.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 28e8b148..8863ea86 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -288,7 +288,10 @@ impl Sender { } match self.extract_rs_pubkey() { - Ok(rs) => self.extract_v2(ohttp_relay, rs), + Ok(_rs) => { + let (req, context_v2) = self.extract_v2(ohttp_relay)?; + Ok((req, Context::V2(context_v2))) + } Err(e) => { log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); let (req, context_v1) = self.extract_v1()?; @@ -302,12 +305,12 @@ impl Sender { /// This method requires the `rs` pubkey to be extracted from the endpoint /// and has no fallback to v1. #[cfg(feature = "v2")] - fn extract_v2( - &mut self, + pub fn extract_v2( + &self, ohttp_relay: Url, - rs: HpkePublicKey, - ) -> Result<(Request, Context), CreateRequestError> { + ) -> Result<(Request, V2PostContext), CreateRequestError> { use crate::uri::UrlExt; + let rs = self.extract_rs_pubkey()?; let url = self.endpoint.clone(); let body = serialize_v2_body( &self.psbt, @@ -329,7 +332,7 @@ impl Sender { log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), - Context::V2(V2PostContext { + V2PostContext { endpoint: self.endpoint.clone(), psbt_ctx: PsbtContext { original_psbt: self.psbt.clone(), @@ -341,7 +344,7 @@ impl Sender { }, hpke_ctx, ohttp_ctx, - }), + }, )) } From 6ba1e530628e40b1f87b541b53227f8c44dd1801 Mon Sep 17 00:00:00 2001 From: spacebear Date: Thu, 7 Nov 2024 13:47:48 -0500 Subject: [PATCH 2/3] Remove Context wrapper and extract_highest_version The send::Context typestate enum is not simple to bind to in UniFFI, and the extract_highest_version function is not very useful because it still requires the caller to match on the resulting Context. --- payjoin-cli/src/app/v2.rs | 40 ++++++++++++++++++++------------- payjoin/src/send/mod.rs | 43 +++++------------------------------- payjoin/tests/integration.rs | 32 +++++++++------------------ 3 files changed, 40 insertions(+), 75 deletions(-) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 3dce6a98..5f6dc11d 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -197,19 +197,18 @@ impl App { } async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { - let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; - println!("Posting Original PSBT Payload request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - println!("Sent fallback transaction"); - match ctx { - payjoin::send::Context::V2(ctx) => { + match req_ctx.extract_v2(self.config.ohttp_relay.clone()) { + Ok((req, ctx)) => { + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); let v2_ctx = Arc::new( ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, ); @@ -239,8 +238,19 @@ impl App { } } } - payjoin::send::Context::V1(ctx) => { - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Err(_) => { + let (req, v1_ctx) = req_ctx.extract_v1()?; + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); + match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { Ok(psbt) => Ok(psbt), Err(re) => { println!("{}", re); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 8863ea86..8400d90b 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -268,38 +268,6 @@ impl Sender { )) } - /// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version. - /// - /// In order to support polling, this may need to be called many times to be encrypted with - /// new unique nonces to make independent OHTTP requests. - /// - /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver - #[cfg(feature = "v2")] - pub fn extract_highest_version( - &mut self, - ohttp_relay: Url, - ) -> Result<(Request, Context), CreateRequestError> { - use crate::uri::UrlExt; - - if let Some(expiry) = self.endpoint.exp() { - if std::time::SystemTime::now() > expiry { - return Err(InternalCreateRequestError::Expired(expiry).into()); - } - } - - match self.extract_rs_pubkey() { - Ok(_rs) => { - let (req, context_v2) = self.extract_v2(ohttp_relay)?; - Ok((req, Context::V2(context_v2))) - } - Err(e) => { - log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); - let (req, context_v1) = self.extract_v1()?; - Ok((req, Context::V1(context_v1))) - } - } - } - /// Extract serialized Request and Context from a Payjoin Proposal. /// /// This method requires the `rs` pubkey to be extracted from the endpoint @@ -310,6 +278,11 @@ impl Sender { ohttp_relay: Url, ) -> Result<(Request, V2PostContext), CreateRequestError> { use crate::uri::UrlExt; + if let Some(expiry) = self.endpoint.exp() { + if std::time::SystemTime::now() > expiry { + return Err(InternalCreateRequestError::Expired(expiry).into()); + } + } let rs = self.extract_rs_pubkey()?; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -369,12 +342,6 @@ impl Sender { pub fn endpoint(&self) -> &Url { &self.endpoint } } -pub enum Context { - V1(V1Context), - #[cfg(feature = "v2")] - V2(V2PostContext), -} - #[derive(Debug, Clone)] pub struct V1Context { psbt_context: PsbtContext, diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index e4701792..82eb4339 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -180,7 +180,6 @@ mod integration { use bitcoin::Address; use http::StatusCode; use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; - use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -285,9 +284,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_highest_version(directory.to_owned()) { + match expired_req_ctx.extract_v2(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -355,14 +354,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; - let send_ctx = match send_ctx { - Context::V2(ctx) => ctx, - _ => panic!("V2 context expected"), - }; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -521,10 +516,10 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, post_ctx) = - req_ctx.extract_highest_version(directory.to_owned())?; + req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -534,11 +529,8 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let get_ctx = match post_ctx { - Context::V2(ctx) => - ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, - _ => panic!("V2 context expected"), - }; + let get_ctx = + post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; let (Request { url, body, content_type, .. }, ohttp_ctx) = get_ctx.extract_req(directory.to_owned())?; let response = agent @@ -622,9 +614,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -636,10 +628,6 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let ctx = match ctx { - Context::V1(ctx) => ctx, - _ => panic!("V1 context expected"), - }; let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; From f6247ff69f3c41f87d5bf68bea04500865157ac5 Mon Sep 17 00:00:00 2001 From: spacebear Date: Thu, 7 Nov 2024 14:00:07 -0500 Subject: [PATCH 3/3] extract post_request --- payjoin-cli/src/app/v2.rs | 55 +++++++++++---------------------------- 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index 5f6dc11d..9e4d23ea 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -137,14 +137,7 @@ impl App { .extract_v2_req() .map_err(|e| anyhow!("v2 req extraction failed {}", e))?; println!("Got a request from the sender. Responding with a Payjoin proposal."); - let http = http_agent()?; - let res = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let res = post_request(req).await?; payjoin_proposal .process_res(res.bytes().await?.to_vec(), ohttp_ctx) .map_err(|e| anyhow!("Failed to deserialize response {}", e))?; @@ -200,27 +193,14 @@ impl App { match req_ctx.extract_v2(self.config.ohttp_relay.clone()) { Ok((req, ctx)) => { println!("Posting Original PSBT Payload request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let response = post_request(req).await?; println!("Sent fallback transaction"); let v2_ctx = Arc::new( ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, ); loop { let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let response = post_request(req).await?; match v2_ctx.process_response( &mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx, @@ -241,14 +221,7 @@ impl App { Err(_) => { let (req, v1_ctx) = req_ctx.extract_v1()?; println!("Posting Original PSBT Payload request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; + let response = post_request(req).await?; println!("Sent fallback transaction"); match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { Ok(psbt) => Ok(psbt), @@ -269,15 +242,7 @@ impl App { loop { let (req, context) = session.extract_req()?; println!("Polling receive request..."); - let http = http_agent()?; - let ohttp_response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - + let ohttp_response = post_request(req).await?; let proposal = session .process_res(ohttp_response.bytes().await?.to_vec().as_slice(), context) .map_err(|_| anyhow!("GET fallback failed"))?; @@ -417,6 +382,16 @@ async fn handle_interrupt(tx: watch::Sender<()>) { let _ = tx.send(()); } +async fn post_request(req: payjoin::Request) -> Result { + let http = http_agent()?; + http.post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err) +} + fn map_reqwest_err(e: reqwest::Error) -> anyhow::Error { match e.status() { Some(status_code) => anyhow!("HTTP request failed: {} {}", status_code, e),