Skip to content

Commit

Permalink
Tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Dec 13, 2023
1 parent 1001583 commit 06b0243
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
38 changes: 24 additions & 14 deletions payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ use payjoin::receive::{PayjoinProposal, UncheckedProposal};
use payjoin::send::RequestContext;
use serde::{Deserialize, Serialize};
#[cfg(feature = "v2")]
use tokio::sync::Mutex as AsyncMutex;
#[cfg(feature = "v2")]
use tokio::task::spawn_blocking;

#[cfg(feature = "danger-local-https")]
Expand All @@ -33,17 +35,25 @@ const LOCAL_CERT_FILE: &str = "localhost.der";
#[derive(Clone)]
pub(crate) struct App {
config: AppConfig,
receive_store: Arc<Mutex<ReceiveStore>>,
send_store: Arc<Mutex<SendStore>>,
#[cfg(feature = "v2")]
receive_store: Arc<AsyncMutex<ReceiveStore>>,
#[cfg(feature = "v2")]
send_store: Arc<AsyncMutex<SendStore>>,
seen_inputs: Arc<Mutex<SeenInputs>>,
}

impl App {
pub fn new(config: AppConfig) -> Result<Self> {
let seen_inputs = Arc::new(Mutex::new(SeenInputs::new()?));
let receive_store = Arc::new(Mutex::new(ReceiveStore::new()?));
let send_store = Arc::new(Mutex::new(SendStore::new()?));
Ok(Self { config, receive_store, send_store, seen_inputs })
#[cfg(feature = "v2")]
let receive_store = Arc::new(AsyncMutex::new(ReceiveStore::new()?));
#[cfg(feature = "v2")]
let send_store = Arc::new(AsyncMutex::new(SendStore::new()?));
#[cfg(feature = "v2")]
let app = Self { config, receive_store, send_store, seen_inputs };
#[cfg(not(feature = "v2"))]
let app = Self { config, seen_inputs };
Ok(app)
}

pub fn bitcoind(&self) -> Result<bitcoincore_rpc::Client> {
Expand All @@ -65,7 +75,7 @@ impl App {

#[cfg(feature = "v2")]
pub async fn send_payjoin(&self, bip21: &str, fee_rate: &f32, is_retry: bool) -> Result<()> {
let mut session = self.send_store.lock().expect("mutex lock failed");
let mut session = self.send_store.lock().await;
let req_ctx = if is_retry {
log::debug!("Resuming session");
// Get a reference to RequestContext
Expand All @@ -79,6 +89,7 @@ impl App {
log::debug!("Awaiting response");
let res = self.long_poll_post(req_ctx).await?;
self.process_pj_response(res)?;
self.send_store.lock().await.clear()?;
Ok(())
}

Expand Down Expand Up @@ -132,13 +143,12 @@ impl App {
let enrolled = enroller
.process_res(ohttp_response.into_reader(), ctx)
.map_err(|_| anyhow!("Enrollment failed"))?;
self.receive_store.lock().expect("mutex lock failed").write(enrolled.clone())?;
self.receive_store.lock().await.write(enrolled.clone())?;
enrolled
} else {
let session = self.receive_store.lock().expect("mutex lock failed");
let session = self.receive_store.lock().await;
log::debug!("Resuming session");
let session = session.session.clone().unwrap();
session
session.session.clone().ok_or(anyhow!("No session found"))?
};

log::debug!("Enrolled receiver");
Expand Down Expand Up @@ -167,7 +177,7 @@ impl App {
let _ = res.into_reader().read_to_end(&mut buf)?;
let res = payjoin_proposal.deserialize_res(buf, ohttp_ctx);
log::debug!("Received response {:?}", res);
self.receive_store.lock().expect("mutex lock failed").clear()?;
self.receive_store.lock().await.clear()?;
Ok(())
}

Expand Down Expand Up @@ -233,7 +243,7 @@ impl App {
}
}

fn create_pj_request<'a>(&self, bip21: &'a str, fee_rate: &f32) -> Result<RequestContext> {
fn create_pj_request(&self, bip21: &str, fee_rate: &f32) -> Result<RequestContext> {
let uri = payjoin::Uri::try_from(bip21)
.map_err(|e| anyhow!("Failed to create URI from BIP21: {}", e))?;

Expand Down Expand Up @@ -300,7 +310,6 @@ impl App {
.bitcoind()?
.send_raw_transaction(&tx)
.with_context(|| "Failed to send raw transaction")?;
self.send_store.lock().expect("mutex lock failed").clear()?;
println!("Payjoin sent: {}", txid);
Ok(txid)
}
Expand Down Expand Up @@ -650,6 +659,7 @@ struct SendStore {
file: std::fs::File,
}

#[cfg(feature = "v2")]
impl SendStore {
fn new() -> Result<Self> {
let mut file =
Expand Down Expand Up @@ -703,7 +713,7 @@ impl ReceiveStore {
}
};

Ok(Self { session: session, file })
Ok(Self { session, file })
}

fn write(
Expand Down
4 changes: 2 additions & 2 deletions payjoin/src/input_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ impl serde::Serialize for InputType {
}
}

#[cfg(feature = "v2")]
impl<'de> serde::Deserialize<'de> for InputType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
Expand All @@ -46,8 +47,7 @@ impl<'de> serde::Deserialize<'de> for InputType {
use InputType::*;

let s = String::deserialize(deserializer)?;
if s.starts_with("SegWitV0: ") {
let rest = &s["SegWitV0: ".len()..];
if let Some(rest) = s.strip_prefix("SegWitV0: ") {
let parts: Vec<&str> = rest.split(", ").collect();
if parts.len() != 2 {
return Err(serde::de::Error::custom("invalid format for SegWitV0"));
Expand Down
18 changes: 11 additions & 7 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ use bitcoin::psbt::Psbt;
use bitcoin::{FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight};
pub use error::{CreateRequestError, ValidationError};
pub(crate) use error::{InternalCreateRequestError, InternalValidationError};
use serde::ser::SerializeStruct;
#[cfg(feature = "v2")]
use serde::{
de::{self, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
use url::Url;

use crate::input_type::InputType;
Expand Down Expand Up @@ -310,6 +315,7 @@ impl<'a> RequestBuilder<'a> {
psbt.validate_input_utxos(true)
.map_err(InternalCreateRequestError::InvalidOriginalInput)?;
let endpoint = self.uri.extras._endpoint.clone();
#[cfg(feature = "v2")]
let ohttp_config = self.uri.extras.ohttp_config;
let disable_output_substitution =
self.uri.extras.disable_output_substitution || self.disable_output_substitution;
Expand Down Expand Up @@ -452,10 +458,10 @@ impl RequestContext {
}

#[cfg(feature = "v2")]
impl serde::Serialize for RequestContext {
impl Serialize for RequestContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
S: Serializer,
{
let mut state = serializer.serialize_struct("RequestContext", 8)?;
state.serialize_field("psbt", &self.psbt.to_string())?;
Expand All @@ -479,17 +485,15 @@ impl serde::Serialize for RequestContext {
}
}

use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};

#[cfg(feature = "v2")]
impl<'de> Deserialize<'de> for RequestContext {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct RequestContextVisitor;

const FIELDS: &'static [&'static str] = &[
const FIELDS: &[&str] = &[
"psbt",
"endpoint",
"ohttp_config",
Expand Down

0 comments on commit 06b0243

Please sign in to comment.