Skip to content

Commit

Permalink
Expose Sender::extract_v2 for bindings (#382)
Browse files Browse the repository at this point in the history
The send::Context typestate enum is not simple to bind to in UniFFI and
would require abstracting distinct extract_v1 extract_v2 functions in
order to cross the FFI boundary. Exposing this method is a simple fix to
make such abstraction unnecessary.

As a consequence, the `enum Context` has also been removed.
  • Loading branch information
DanGould authored Nov 11, 2024
2 parents e821a41 + f6247ff commit ef2ce55
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 100 deletions.
63 changes: 24 additions & 39 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,7 @@ impl App {
.extract_v2_req()
.map_err(|e| anyhow!("v2 req extraction failed {}", e))?;
println!("Got a request from the sender. Responding with a Payjoin proposal.");
let http = http_agent()?;
let res = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let res = post_request(req).await?;
payjoin_proposal
.process_res(res.bytes().await?.to_vec(), ohttp_ctx)
.map_err(|e| anyhow!("Failed to deserialize response {}", e))?;
Expand Down Expand Up @@ -197,31 +190,17 @@ impl App {
}

async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result<Psbt> {
let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?;
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match ctx {
payjoin::send::Context::V2(ctx) => {
match req_ctx.extract_v2(self.config.ohttp_relay.clone()) {
Ok((req, ctx)) => {
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
loop {
let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
let response = post_request(req).await?;
match v2_ctx.process_response(
&mut response.bytes().await?.to_vec().as_slice(),
ohttp_ctx,
Expand All @@ -239,8 +218,12 @@ impl App {
}
}
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Err(_) => {
let (req, v1_ctx) = req_ctx.extract_v1()?;
println!("Posting Original PSBT Payload request...");
let response = post_request(req).await?;
println!("Sent fallback transaction");
match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{}", re);
Expand All @@ -259,15 +242,7 @@ impl App {
loop {
let (req, context) = session.extract_req()?;
println!("Polling receive request...");
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;

let ohttp_response = post_request(req).await?;
let proposal = session
.process_res(ohttp_response.bytes().await?.to_vec().as_slice(), context)
.map_err(|_| anyhow!("GET fallback failed"))?;
Expand Down Expand Up @@ -407,6 +382,16 @@ async fn handle_interrupt(tx: watch::Sender<()>) {
let _ = tx.send(());
}

async fn post_request(req: payjoin::Request) -> Result<reqwest::Response> {
let http = http_agent()?;
http.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)
}

fn map_reqwest_err(e: reqwest::Error) -> anyhow::Error {
match e.status() {
Some(status_code) => anyhow!("HTTP request failed: {} {}", status_code, e),
Expand Down
48 changes: 9 additions & 39 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,46 +268,22 @@ impl Sender {
))
}

/// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version.
///
/// In order to support polling, this may need to be called many times to be encrypted with
/// new unique nonces to make independent OHTTP requests.
/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver
/// This method requires the `rs` pubkey to be extracted from the endpoint
/// and has no fallback to v1.
#[cfg(feature = "v2")]
pub fn extract_highest_version(
&mut self,
pub fn extract_v2(
&self,
ohttp_relay: Url,
) -> Result<(Request, Context), CreateRequestError> {
) -> Result<(Request, V2PostContext), CreateRequestError> {
use crate::uri::UrlExt;

if let Some(expiry) = self.endpoint.exp() {
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}

match self.extract_rs_pubkey() {
Ok(rs) => self.extract_v2(ohttp_relay, rs),
Err(e) => {
log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e);
let (req, context_v1) = self.extract_v1()?;
Ok((req, Context::V1(context_v1)))
}
}
}

/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// This method requires the `rs` pubkey to be extracted from the endpoint
/// and has no fallback to v1.
#[cfg(feature = "v2")]
fn extract_v2(
&mut self,
ohttp_relay: Url,
rs: HpkePublicKey,
) -> Result<(Request, Context), CreateRequestError> {
use crate::uri::UrlExt;
let rs = self.extract_rs_pubkey()?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
&self.psbt,
Expand All @@ -329,7 +305,7 @@ impl Sender {
log::debug!("ohttp_relay_url: {:?}", ohttp_relay);
Ok((
Request::new_v2(ohttp_relay, body),
Context::V2(V2PostContext {
V2PostContext {
endpoint: self.endpoint.clone(),
psbt_ctx: PsbtContext {
original_psbt: self.psbt.clone(),
Expand All @@ -341,7 +317,7 @@ impl Sender {
},
hpke_ctx,
ohttp_ctx,
}),
},
))
}

Expand All @@ -366,12 +342,6 @@ impl Sender {
pub fn endpoint(&self) -> &Url { &self.endpoint }
}

pub enum Context {
V1(V1Context),
#[cfg(feature = "v2")]
V2(V2PostContext),
}

#[derive(Debug, Clone)]
pub struct V1Context {
psbt_context: PsbtContext,
Expand Down
32 changes: 10 additions & 22 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ mod integration {
use bitcoin::Address;
use http::StatusCode;
use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal};
use payjoin::send::Context;
use payjoin::{OhttpKeys, PjUri, UriExt};
use reqwest::{Client, ClientBuilder, Error, Response};
use testcontainers_modules::redis::Redis;
Expand Down Expand Up @@ -285,9 +284,9 @@ mod integration {
Some(std::time::SystemTime::now()),
)
.build();
let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
.build_non_incentivizing(FeeRate::BROADCAST_MIN)?;
match expired_req_ctx.extract_highest_version(directory.to_owned()) {
match expired_req_ctx.extract_v2(directory.to_owned()) {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => assert!(false, "Expired send session should error"),
Expand Down Expand Up @@ -355,14 +354,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, send_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
let send_ctx = match send_ctx {
Context::V2(ctx) => ctx,
_ => panic!("V2 context expected"),
};
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand Down Expand Up @@ -521,10 +516,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, post_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand All @@ -534,11 +529,8 @@ mod integration {
.unwrap();
log::info!("Response: {:#?}", &response);
assert!(response.status().is_success());
let get_ctx = match post_ctx {
Context::V2(ctx) =>
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
_ => panic!("V2 context expected"),
};
let get_ctx =
post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?;
let (Request { url, body, content_type, .. }, ohttp_ctx) =
get_ctx.extract_req(directory.to_owned())?;
let response = agent
Expand Down Expand Up @@ -622,9 +614,9 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_original_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?;
let (req, ctx) = req_ctx.extract_v1()?;
let headers = HeaderMock::new(&req.body, req.content_type);

// **********************
Expand All @@ -636,10 +628,6 @@ mod integration {
// **********************
// Inside the Sender:
// Sender checks, signs, finalizes, extracts, and broadcasts
let ctx = match ctx {
Context::V1(ctx) => ctx,
_ => panic!("V1 context expected"),
};
let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?;
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
Expand Down

0 comments on commit ef2ce55

Please sign in to comment.