Skip to content

Commit

Permalink
Persist send sessions for async pj
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Dec 13, 2023
1 parent 8ac1570 commit 1001583
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
target
*config.toml
*seen_inputs.json
*session_store.json
*_store.json
64 changes: 59 additions & 5 deletions payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ const LOCAL_CERT_FILE: &str = "localhost.der";
pub(crate) struct App {
config: AppConfig,
receive_store: Arc<Mutex<ReceiveStore>>,
send_store: Arc<Mutex<SendStore>>,
seen_inputs: Arc<Mutex<SeenInputs>>,
}

impl App {
pub fn new(config: AppConfig) -> Result<Self> {
let seen_inputs = Arc::new(Mutex::new(SeenInputs::new()?));
let receive_store = Arc::new(Mutex::new(ReceiveStore::new()?));
Ok(Self { config, receive_store, seen_inputs })
let send_store = Arc::new(Mutex::new(SendStore::new()?));
Ok(Self { config, receive_store, send_store, seen_inputs })
}

pub fn bitcoind(&self) -> Result<bitcoincore_rpc::Client> {
Expand All @@ -62,8 +64,18 @@ impl App {
}

#[cfg(feature = "v2")]
pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()> {
let req_ctx = self.create_pj_request(bip21, fee_rate)?;
pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32, is_retry: bool) -> Result<()> {
let mut session = self.send_store.lock().expect("mutex lock failed");
let req_ctx = if is_retry {
log::debug!("Resuming session");
// Get a reference to RequestContext
session.req_ctx.as_ref().expect("RequestContext is missing")
} else {
let req_ctx = self.create_pj_request(bip21, fee_rate)?;
session.write(req_ctx)?;
log::debug!("Writing req_ctx");
session.req_ctx.as_ref().expect("RequestContext is missing")
};
log::debug!("Awaiting response");
let res = self.long_poll_post(req_ctx).await?;
self.process_pj_response(res)?;
Expand Down Expand Up @@ -173,7 +185,7 @@ impl App {
}

#[cfg(feature = "v2")]
async fn long_poll_post(&self, req_ctx: payjoin::send::RequestContext<'_>) -> Result<Psbt> {
async fn long_poll_post(&self, req_ctx: &payjoin::send::RequestContext) -> Result<Psbt> {
loop {
let (req, ctx, ohttp) = req_ctx.extract_v2(&self.config.ohttp_proxy)?;
println!("Sending fallback request to {}", &req.url);
Expand Down Expand Up @@ -221,7 +233,7 @@ impl App {
}
}

fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result<RequestContext<'a>> {
fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result<RequestContext> {
let uri = payjoin::Uri::try_from(bip21)
.map_err(|e| anyhow!("Failed to create URI from BIP21: {}", e))?;

Expand Down Expand Up @@ -288,6 +300,7 @@ impl App {
.bitcoind()?
.send_raw_transaction(&tx)
.with_context(|| "Failed to send raw transaction")?;
self.send_store.lock().expect("mutex lock failed").clear()?;
println!("Payjoin sent: {}", txid);
Ok(txid)
}
Expand Down Expand Up @@ -631,12 +644,53 @@ impl App {
}
}

#[cfg(feature = "v2")]
struct SendStore {
req_ctx: Option<payjoin::send::RequestContext>,
file: std::fs::File,
}

impl SendStore {
fn new() -> Result<Self> {
let mut file =
OpenOptions::new().write(true).read(true).create(true).open("send_store.json")?;
let session = match serde_json::from_reader(&mut file) {
Ok(session) => Some(session),
Err(e) => {
log::debug!("error reading send session store: {}", e);
None
}
};

Ok(Self { req_ctx: session, file })
}

fn write(
&mut self,
session: payjoin::send::RequestContext,
) -> Result<&mut payjoin::send::RequestContext> {
use std::io::Write;

let session = self.req_ctx.insert(session);
let serialized = serde_json::to_string(session)?;
self.file.write_all(serialized.as_bytes())?;
Ok(session)
}

fn clear(&mut self) -> Result<()> {
let file = OpenOptions::new().write(true).open("send_store.json")?;
file.set_len(0)?;
Ok(())
}
}

#[cfg(feature = "v2")]
struct ReceiveStore {
session: Option<payjoin::receive::v2::Enrolled>,
file: std::fs::File,
}

#[cfg(feature = "v2")]
impl ReceiveStore {
fn new() -> Result<Self> {
let mut file =
Expand Down
18 changes: 11 additions & 7 deletions payjoin-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@ async fn main() -> Result<()> {
let bip21 = sub_matches.get_one::<String>("BIP21").context("Missing BIP21 argument")?;
let fee_rate_sat_per_vb =
sub_matches.get_one::<f32>("fee_rate").context("Missing --fee-rate argument")?;
#[cfg(feature = "v2")]
let is_retry = matches.get_one::<bool>("retry").context("Could not read --retry")?;
#[cfg(feature = "v2")]
app.send_payjoin(bip21, fee_rate_sat_per_vb, *is_retry).await?;
#[cfg(not(feature = "v2"))]
app.send_payjoin(bip21, fee_rate_sat_per_vb).await?;
}
Some(("receive", sub_matches)) => {
let amount =
sub_matches.get_one::<String>("AMOUNT").context("Missing AMOUNT argument")?;
#[cfg(feature = "v2")]
let is_retry =
sub_matches.get_one::<bool>("retry").context("Could not read --retry")?;
let is_retry = matches.get_one::<bool>("retry").context("Could not read --retry")?;
#[cfg(feature = "v2")]
app.receive_payjoin(amount, *is_retry).await?;
#[cfg(not(feature = "v2"))]
Expand Down Expand Up @@ -64,6 +68,11 @@ fn cli() -> ArgMatches {
.arg(Arg::new("ohttp_proxy")
.long("ohttp-proxy")
.help("The ohttp proxy url"))
.arg(Arg::new("retry")
.long("retry")
.short('e')
.action(clap::ArgAction::SetTrue)
.help("Retry the asynchronous payjoin request if it did not yet complete"))
.subcommand(
Command::new("send")
.arg_required_else_help(true)
Expand Down Expand Up @@ -91,11 +100,6 @@ fn cli() -> ArgMatches {
.short('e')
.takes_value(true)
.help("The `pj=` endpoint to receive the payjoin request"))
.arg(Arg::new("retry")
.long("retry")
.short('r')
.action(clap::ArgAction::SetTrue)
.help("Retry the asynchronous payjoin request if it did not yet complete"))
.arg(Arg::new("sub_only")
.long("sub-only")
.short('s')
Expand Down
56 changes: 56 additions & 0 deletions payjoin/src/input_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,62 @@ pub(crate) enum InputType {
Taproot,
}

#[cfg(feature = "v2")]
impl serde::Serialize for InputType {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use InputType::*;

match self {
P2Pk => serializer.serialize_str("P2PK"),
P2Pkh => serializer.serialize_str("P2PKH"),
P2Sh => serializer.serialize_str("P2SH"),
SegWitV0 { ty, nested } =>
serializer.serialize_str(&format!("SegWitV0: type={}, nested={}", ty, nested)),
Taproot => serializer.serialize_str("Taproot"),
}
}
}

impl<'de> serde::Deserialize<'de> for InputType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use InputType::*;

let s = String::deserialize(deserializer)?;
if s.starts_with("SegWitV0: ") {
let rest = &s["SegWitV0: ".len()..];
let parts: Vec<&str> = rest.split(", ").collect();
if parts.len() != 2 {
return Err(serde::de::Error::custom("invalid format for SegWitV0"));
}
log::debug!("parts: {:?}", parts);
let ty = match parts[0].strip_prefix("type=") {
Some("pubkey") => SegWitV0Type::Pubkey,
Some("script") => SegWitV0Type::Script,
_ => return Err(serde::de::Error::custom("invalid SegWitV0 type")),
};

let nested = match parts[1].strip_prefix("nested=") {
Some("true") => true,
Some("false") => false,
_ => return Err(serde::de::Error::custom("invalid SegWitV0 nested value")),
};

Ok(SegWitV0 { ty, nested })
} else {
match s.as_str() {
"P2PK" => Ok(P2Pk),
"P2PKH" => Ok(P2Pkh),
"P2SH" => Ok(P2Sh),
"Taproot" => Ok(Taproot),
_ => Err(serde::de::Error::custom("invalid type")),
}
}
}
}

impl InputType {
pub(crate) fn from_spent_input(
txout: &TxOut,
Expand Down
Loading

0 comments on commit 1001583

Please sign in to comment.