Skip to content

Commit

Permalink
Remove Context wrapper and extract_highest_version
Browse files Browse the repository at this point in the history
The send::Context typestate enum is not simple to bind to in UniFFI, and
the extract_highest_version function is not very useful because it still
requires the caller to match on the resulting Context.
  • Loading branch information
spacebear21 committed Nov 7, 2024
1 parent 1970a59 commit 6ba1e53
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 75 deletions.
40 changes: 25 additions & 15 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,18 @@ impl App {
}

async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result<Psbt> {
let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?;
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match ctx {
payjoin::send::Context::V2(ctx) => {
match req_ctx.extract_v2(self.config.ohttp_relay.clone()) {
Ok((req, ctx)) => {
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
let v2_ctx = Arc::new(
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
);
Expand Down Expand Up @@ -239,8 +238,19 @@ impl App {
}
}
}
payjoin::send::Context::V1(ctx) => {
match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Err(_) => {
let (req, v1_ctx) = req_ctx.extract_v1()?;
println!("Posting Original PSBT Payload request...");
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
.map_err(map_reqwest_err)?;
println!("Sent fallback transaction");
match v1_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) {
Ok(psbt) => Ok(psbt),
Err(re) => {
println!("{}", re);
Expand Down
43 changes: 5 additions & 38 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,38 +268,6 @@ impl Sender {
))
}

/// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version.
///
/// In order to support polling, this may need to be called many times to be encrypted with
/// new unique nonces to make independent OHTTP requests.
///
/// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver
#[cfg(feature = "v2")]
pub fn extract_highest_version(
&mut self,
ohttp_relay: Url,
) -> Result<(Request, Context), CreateRequestError> {
use crate::uri::UrlExt;

if let Some(expiry) = self.endpoint.exp() {
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}

match self.extract_rs_pubkey() {
Ok(_rs) => {
let (req, context_v2) = self.extract_v2(ohttp_relay)?;
Ok((req, Context::V2(context_v2)))
}
Err(e) => {
log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e);
let (req, context_v1) = self.extract_v1()?;
Ok((req, Context::V1(context_v1)))
}
}
}

/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// This method requires the `rs` pubkey to be extracted from the endpoint
Expand All @@ -310,6 +278,11 @@ impl Sender {
ohttp_relay: Url,
) -> Result<(Request, V2PostContext), CreateRequestError> {
use crate::uri::UrlExt;
if let Some(expiry) = self.endpoint.exp() {
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}
let rs = self.extract_rs_pubkey()?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand Down Expand Up @@ -369,12 +342,6 @@ impl Sender {
pub fn endpoint(&self) -> &Url { &self.endpoint }
}

pub enum Context {
V1(V1Context),
#[cfg(feature = "v2")]
V2(V2PostContext),
}

#[derive(Debug, Clone)]
pub struct V1Context {
psbt_context: PsbtContext,
Expand Down
32 changes: 10 additions & 22 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ mod integration {
use bitcoin::Address;
use http::StatusCode;
use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal};
use payjoin::send::Context;
use payjoin::{OhttpKeys, PjUri, UriExt};
use reqwest::{Client, ClientBuilder, Error, Response};
use testcontainers_modules::redis::Redis;
Expand Down Expand Up @@ -285,9 +284,9 @@ mod integration {
Some(std::time::SystemTime::now()),
)
.build();
let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
let expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)?
.build_non_incentivizing(FeeRate::BROADCAST_MIN)?;
match expired_req_ctx.extract_highest_version(directory.to_owned()) {
match expired_req_ctx.extract_v2(directory.to_owned()) {
// Internal error types are private, so check against a string
Err(err) => assert!(err.to_string().contains("expired")),
_ => assert!(false, "Expired send session should error"),
Expand Down Expand Up @@ -355,14 +354,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, send_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
let send_ctx = match send_ctx {
Context::V2(ctx) => ctx,
_ => panic!("V2 context expected"),
};
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand Down Expand Up @@ -521,10 +516,10 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_sweep_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (Request { url, body, content_type, .. }, post_ctx) =
req_ctx.extract_highest_version(directory.to_owned())?;
req_ctx.extract_v2(directory.to_owned())?;
let response = agent
.post(url.clone())
.header("Content-Type", content_type)
Expand All @@ -534,11 +529,8 @@ mod integration {
.unwrap();
log::info!("Response: {:#?}", &response);
assert!(response.status().is_success());
let get_ctx = match post_ctx {
Context::V2(ctx) =>
ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?,
_ => panic!("V2 context expected"),
};
let get_ctx =
post_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?;
let (Request { url, body, content_type, .. }, ohttp_ctx) =
get_ctx.extract_req(directory.to_owned())?;
let response = agent
Expand Down Expand Up @@ -622,9 +614,9 @@ mod integration {
.check_pj_supported()
.unwrap();
let psbt = build_original_psbt(&sender, &pj_uri)?;
let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
let req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())?
.build_recommended(FeeRate::BROADCAST_MIN)?;
let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?;
let (req, ctx) = req_ctx.extract_v1()?;
let headers = HeaderMock::new(&req.body, req.content_type);

// **********************
Expand All @@ -636,10 +628,6 @@ mod integration {
// **********************
// Inside the Sender:
// Sender checks, signs, finalizes, extracts, and broadcasts
let ctx = match ctx {
Context::V1(ctx) => ctx,
_ => panic!("V1 context expected"),
};
let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?;
let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?;
sender.send_raw_transaction(&payjoin_tx)?;
Expand Down

0 comments on commit 6ba1e53

Please sign in to comment.