From 28cef51ed8d8a5f3538f31441f51fbf8fccd09b2 Mon Sep 17 00:00:00 2001 From: DanGould Date: Thu, 30 Nov 2023 15:09:38 -0500 Subject: [PATCH] Impl traits for session persistence --- Cargo.lock | 1 + payjoin/Cargo.toml | 3 +- payjoin/src/receive/v2.rs | 129 +++++++++++++++++++++++++++++++++++++- 3 files changed, 131 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2ef922f7..b706121c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1811,6 +1811,7 @@ dependencies = [ "log", "ohttp", "rand", + "serde", "url", ] diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index 3757b2a7..f703eca3 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -18,7 +18,7 @@ exclude = ["tests"] send = [] receive = ["rand"] base64 = ["bitcoin/base64"] -v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"] +v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp", "serde"] [dependencies] bitcoin = { version = "0.30.0", features = ["base64"] } @@ -28,6 +28,7 @@ log = { version = "0.4.14"} ohttp = { version = "0.4.0", optional = true } bhttp = { version = "0.4.0", optional = true } rand = { version = "0.8.4", optional = true } +serde = { version = "1.0.186", default-features = false, optional = true } url = "2.2.2" [dev-dependencies] diff --git a/payjoin/src/receive/v2.rs b/payjoin/src/receive/v2.rs index 11a38a82..6c536000 100644 --- a/payjoin/src/receive/v2.rs +++ b/payjoin/src/receive/v2.rs @@ -2,6 +2,8 @@ use std::collections::HashMap; use bitcoin::psbt::Psbt; use bitcoin::{base64, Amount, FeeRate, OutPoint, Script, TxOut}; +use serde::ser::SerializeStruct; +use serde::{Serialize, Deserialize, Serializer}; use super::{Error, InternalRequestError, RequestError, SelectionError}; use crate::psbt::PsbtExt; @@ -107,7 +109,7 @@ fn subdirectory(pubkey: &bitcoin::secp256k1::PublicKey) -> String { pubkey_base64 } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct Enrolled { relay_url: url::Url, ohttp_config: Vec, @@ -115,6 +117,127 @@ pub struct Enrolled { s: bitcoin::secp256k1::KeyPair, } +impl Serialize for Enrolled { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("Enrolled", 4)?; + state.serialize_field("relay_url", &self.relay_url.to_string())?; + state.serialize_field("ohttp_config", &self.ohttp_config)?; + state.serialize_field("ohttp_proxy", &self.ohttp_proxy.to_string())?; + state.serialize_field("s", &self.s.secret_key().secret_bytes())?; + + state.end() + } +} + +use serde::de::{self, Deserializer, Visitor, SeqAccess, MapAccess}; +use std::fmt; +use std::str::FromStr; + +impl<'de> Deserialize<'de> for Enrolled { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { RelayUrl, OhttpConfig, OhttpProxy, S } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`relay_url`, `ohttp_config`, `ohttp_proxy`, or `s`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "relay_url" => Ok(Field::RelayUrl), + "ohttp_config" => Ok(Field::OhttpConfig), + "ohttp_proxy" => Ok(Field::OhttpProxy), + "s" => Ok(Field::S), + _ => Err(de::Error::unknown_field(value, FIELDS)), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct EnrolledVisitor; + + impl<'de> Visitor<'de> for EnrolledVisitor { + type Value = Enrolled; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Enrolled") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut relay_url = None; + let mut ohttp_config = None; + let mut ohttp_proxy = None; + let mut s = None; + while let Some(key) = map.next_key()? { + match key { + Field::RelayUrl => { + if relay_url.is_some() { + return Err(de::Error::duplicate_field("relay_url")); + } + let url_str: String = map.next_value()?; + relay_url = Some(url::Url::parse(&url_str).map_err(de::Error::custom)?); + }, + Field::OhttpConfig => { + if ohttp_config.is_some() { + return Err(de::Error::duplicate_field("ohttp_config")); + } + ohttp_config = Some(map.next_value()?); + }, + Field::OhttpProxy => { + if ohttp_proxy.is_some() { + return Err(de::Error::duplicate_field("ohttp_proxy")); + } + let proxy_str: String = map.next_value()?; + ohttp_proxy = Some(url::Url::parse(&proxy_str).map_err(de::Error::custom)?); + }, + Field::S => { + if s.is_some() { + return Err(de::Error::duplicate_field("s")); + } + let s_bytes: Vec = map.next_value()?; + let secp = bitcoin::secp256k1::Secp256k1::new(); + s = Some(bitcoin::secp256k1::KeyPair::from_seckey_slice(&secp, &s_bytes) + .map_err(de::Error::custom)?); + } + } + } + let relay_url = relay_url.ok_or_else(|| de::Error::missing_field("relay_url"))?; + let ohttp_config = ohttp_config.ok_or_else(|| de::Error::missing_field("ohttp_config"))?; + let ohttp_proxy = ohttp_proxy.ok_or_else(|| de::Error::missing_field("ohttp_proxy"))?; + let s = s.ok_or_else(|| de::Error::missing_field("s"))?; + Ok(Enrolled { relay_url, ohttp_config, ohttp_proxy, s }) + } + } + + const FIELDS: &'static [&'static str] = &["relay_url", "ohttp_config", "ohttp_proxy", "s"]; + deserializer.deserialize_struct("Enrolled", FIELDS, EnrolledVisitor) + } +} + impl Enrolled { pub fn extract_req(&self) -> Result<(Request, ohttp::ClientResponse), Error> { let (body, ohttp_ctx) = self.fallback_req_body()?; @@ -174,6 +297,10 @@ impl Enrolled { crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &self.fallback_target(), None) } + pub fn pubkey(&self) -> [u8; 33] { + self.s.public_key().serialize() + } + pub fn fallback_target(&self) -> String { let pubkey = &self.s.public_key().serialize(); let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);