From 06b024377c6c36c90a753341deb97d55911a962c Mon Sep 17 00:00:00 2001 From: DanGould Date: Wed, 13 Dec 2023 11:59:05 -0500 Subject: [PATCH] Tidy up --- payjoin-cli/src/app.rs | 38 ++++++++++++++++++++++++-------------- payjoin/src/input_type.rs | 4 ++-- payjoin/src/send/mod.rs | 18 +++++++++++------- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/payjoin-cli/src/app.rs b/payjoin-cli/src/app.rs index 5a3f662c..7903ff8a 100644 --- a/payjoin-cli/src/app.rs +++ b/payjoin-cli/src/app.rs @@ -25,6 +25,8 @@ use payjoin::receive::{PayjoinProposal, UncheckedProposal}; use payjoin::send::RequestContext; use serde::{Deserialize, Serialize}; #[cfg(feature = "v2")] +use tokio::sync::Mutex as AsyncMutex; +#[cfg(feature = "v2")] use tokio::task::spawn_blocking; #[cfg(feature = "danger-local-https")] @@ -33,17 +35,25 @@ const LOCAL_CERT_FILE: &str = "localhost.der"; #[derive(Clone)] pub(crate) struct App { config: AppConfig, - receive_store: Arc>, - send_store: Arc>, + #[cfg(feature = "v2")] + receive_store: Arc>, + #[cfg(feature = "v2")] + send_store: Arc>, seen_inputs: Arc>, } impl App { pub fn new(config: AppConfig) -> Result { let seen_inputs = Arc::new(Mutex::new(SeenInputs::new()?)); - let receive_store = Arc::new(Mutex::new(ReceiveStore::new()?)); - let send_store = Arc::new(Mutex::new(SendStore::new()?)); - Ok(Self { config, receive_store, send_store, seen_inputs }) + #[cfg(feature = "v2")] + let receive_store = Arc::new(AsyncMutex::new(ReceiveStore::new()?)); + #[cfg(feature = "v2")] + let send_store = Arc::new(AsyncMutex::new(SendStore::new()?)); + #[cfg(feature = "v2")] + let app = Self { config, receive_store, send_store, seen_inputs }; + #[cfg(not(feature = "v2"))] + let app = Self { config, seen_inputs }; + Ok(app) } pub fn bitcoind(&self) -> Result { @@ -65,7 +75,7 @@ impl App { #[cfg(feature = "v2")] pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32, is_retry: bool) -> Result<()> { - let mut session = self.send_store.lock().expect("mutex lock failed"); + let mut session = self.send_store.lock().await; let req_ctx = if is_retry { log::debug!("Resuming session"); // Get a reference to RequestContext @@ -79,6 +89,7 @@ impl App { log::debug!("Awaiting response"); let res = self.long_poll_post(req_ctx).await?; self.process_pj_response(res)?; + self.send_store.lock().await.clear()?; Ok(()) } @@ -132,13 +143,12 @@ impl App { let enrolled = enroller .process_res(ohttp_response.into_reader(), ctx) .map_err(|_| anyhow!("Enrollment failed"))?; - self.receive_store.lock().expect("mutex lock failed").write(enrolled.clone())?; + self.receive_store.lock().await.write(enrolled.clone())?; enrolled } else { - let session = self.receive_store.lock().expect("mutex lock failed"); + let session = self.receive_store.lock().await; log::debug!("Resuming session"); - let session = session.session.clone().unwrap(); - session + session.session.clone().ok_or(anyhow!("No session found"))? }; log::debug!("Enrolled receiver"); @@ -167,7 +177,7 @@ impl App { let _ = res.into_reader().read_to_end(&mut buf)?; let res = payjoin_proposal.deserialize_res(buf, ohttp_ctx); log::debug!("Received response {:?}", res); - self.receive_store.lock().expect("mutex lock failed").clear()?; + self.receive_store.lock().await.clear()?; Ok(()) } @@ -233,7 +243,7 @@ impl App { } } - fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result { + fn create_pj_request(&self, bip21: &str, fee_rate: &f32) -> Result { let uri = payjoin::Uri::try_from(bip21) .map_err(|e| anyhow!("Failed to create URI from BIP21: {}", e))?; @@ -300,7 +310,6 @@ impl App { .bitcoind()? .send_raw_transaction(&tx) .with_context(|| "Failed to send raw transaction")?; - self.send_store.lock().expect("mutex lock failed").clear()?; println!("Payjoin sent: {}", txid); Ok(txid) } @@ -650,6 +659,7 @@ struct SendStore { file: std::fs::File, } +#[cfg(feature = "v2")] impl SendStore { fn new() -> Result { let mut file = @@ -703,7 +713,7 @@ impl ReceiveStore { } }; - Ok(Self { session: session, file }) + Ok(Self { session, file }) } fn write( diff --git a/payjoin/src/input_type.rs b/payjoin/src/input_type.rs index 8601c05c..48a51390 100644 --- a/payjoin/src/input_type.rs +++ b/payjoin/src/input_type.rs @@ -38,6 +38,7 @@ impl serde::Serialize for InputType { } } +#[cfg(feature = "v2")] impl<'de> serde::Deserialize<'de> for InputType { fn deserialize(deserializer: D) -> Result where @@ -46,8 +47,7 @@ impl<'de> serde::Deserialize<'de> for InputType { use InputType::*; let s = String::deserialize(deserializer)?; - if s.starts_with("SegWitV0: ") { - let rest = &s["SegWitV0: ".len()..]; + if let Some(rest) = s.strip_prefix("SegWitV0: ") { let parts: Vec<&str> = rest.split(", ").collect(); if parts.len() != 2 { return Err(serde::de::Error::custom("invalid format for SegWitV0")); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 61511555..4ca75ec6 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -142,7 +142,12 @@ use bitcoin::psbt::Psbt; use bitcoin::{FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight}; pub use error::{CreateRequestError, ValidationError}; pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; -use serde::ser::SerializeStruct; +#[cfg(feature = "v2")] +use serde::{ + de::{self, MapAccess, Visitor}, + ser::SerializeStruct, + Deserialize, Deserializer, Serialize, Serializer, +}; use url::Url; use crate::input_type::InputType; @@ -310,6 +315,7 @@ impl<'a> RequestBuilder<'a> { psbt.validate_input_utxos(true) .map_err(InternalCreateRequestError::InvalidOriginalInput)?; let endpoint = self.uri.extras._endpoint.clone(); + #[cfg(feature = "v2")] let ohttp_config = self.uri.extras.ohttp_config; let disable_output_substitution = self.uri.extras.disable_output_substitution || self.disable_output_substitution; @@ -452,10 +458,10 @@ impl RequestContext { } #[cfg(feature = "v2")] -impl serde::Serialize for RequestContext { +impl Serialize for RequestContext { fn serialize(&self, serializer: S) -> Result where - S: serde::Serializer, + S: Serializer, { let mut state = serializer.serialize_struct("RequestContext", 8)?; state.serialize_field("psbt", &self.psbt.to_string())?; @@ -479,9 +485,7 @@ impl serde::Serialize for RequestContext { } } -use serde::de::{self, MapAccess, Visitor}; -use serde::{Deserialize, Deserializer}; - +#[cfg(feature = "v2")] impl<'de> Deserialize<'de> for RequestContext { fn deserialize(deserializer: D) -> Result where @@ -489,7 +493,7 @@ impl<'de> Deserialize<'de> for RequestContext { { struct RequestContextVisitor; - const FIELDS: &'static [&'static str] = &[ + const FIELDS: &[&str] = &[ "psbt", "endpoint", "ohttp_config",