diff --git a/Cargo.lock b/Cargo.lock index 317658ca..049a99ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1716,6 +1716,7 @@ dependencies = [ "ohttp", "rand", "rustls", + "serde", "testcontainers", "testcontainers-modules", "tokio", diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index a6e36f6c..9e73fa22 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -18,7 +18,7 @@ exclude = ["tests"] send = [] receive = ["rand"] base64 = ["bitcoin/base64"] -v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"] +v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp", "serde"] [dependencies] bitcoin = { version = "0.30.0", features = ["base64"] } @@ -28,6 +28,7 @@ log = { version = "0.4.14"} ohttp = { version = "0.4.0", optional = true } bhttp = { version = "0.4.0", optional = true } rand = { version = "0.8.4", optional = true } +serde = { version = "1.0.186", default-features = false, optional = true } url = "2.2.2" [dev-dependencies] diff --git a/payjoin/src/receive/v2.rs b/payjoin/src/receive/v2.rs index 204358eb..eccab925 100644 --- a/payjoin/src/receive/v2.rs +++ b/payjoin/src/receive/v2.rs @@ -2,6 +2,8 @@ use std::collections::HashMap; use bitcoin::psbt::Psbt; use bitcoin::{base64, Amount, FeeRate, OutPoint, Script, TxOut}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize, Serializer}; use super::{Error, InternalRequestError, RequestError, SelectionError}; use crate::psbt::PsbtExt; @@ -107,7 +109,7 @@ fn subdirectory(pubkey: &bitcoin::secp256k1::PublicKey) -> String { pubkey_base64 } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq)] pub struct Enrolled { relay_url: url::Url, ohttp_config: Vec, @@ -115,6 +117,138 @@ pub struct Enrolled { s: bitcoin::secp256k1::KeyPair, } +impl Serialize for Enrolled { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut state = serializer.serialize_struct("Enrolled", 4)?; + state.serialize_field("relay_url", &self.relay_url.to_string())?; + state.serialize_field("ohttp_config", &self.ohttp_config)?; + state.serialize_field("ohttp_proxy", &self.ohttp_proxy.to_string())?; + state.serialize_field("s", &self.s.secret_key().secret_bytes())?; + + state.end() + } +} + +use std::fmt; +use std::str::FromStr; + +use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor}; + +impl<'de> Deserialize<'de> for Enrolled { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + enum Field { + RelayUrl, + OhttpConfig, + OhttpProxy, + S, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`relay_url`, `ohttp_config`, `ohttp_proxy`, or `s`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "relay_url" => Ok(Field::RelayUrl), + "ohttp_config" => Ok(Field::OhttpConfig), + "ohttp_proxy" => Ok(Field::OhttpProxy), + "s" => Ok(Field::S), + _ => Err(de::Error::unknown_field(value, FIELDS)), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct EnrolledVisitor; + + impl<'de> Visitor<'de> for EnrolledVisitor { + type Value = Enrolled; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct Enrolled") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut relay_url = None; + let mut ohttp_config = None; + let mut ohttp_proxy = None; + let mut s = None; + while let Some(key) = map.next_key()? { + match key { + Field::RelayUrl => { + if relay_url.is_some() { + return Err(de::Error::duplicate_field("relay_url")); + } + let url_str: String = map.next_value()?; + relay_url = Some(url::Url::parse(&url_str).map_err(de::Error::custom)?); + } + Field::OhttpConfig => { + if ohttp_config.is_some() { + return Err(de::Error::duplicate_field("ohttp_config")); + } + ohttp_config = Some(map.next_value()?); + } + Field::OhttpProxy => { + if ohttp_proxy.is_some() { + return Err(de::Error::duplicate_field("ohttp_proxy")); + } + let proxy_str: String = map.next_value()?; + ohttp_proxy = + Some(url::Url::parse(&proxy_str).map_err(de::Error::custom)?); + } + Field::S => { + if s.is_some() { + return Err(de::Error::duplicate_field("s")); + } + let s_bytes: Vec = map.next_value()?; + let secp = bitcoin::secp256k1::Secp256k1::new(); + s = Some( + bitcoin::secp256k1::KeyPair::from_seckey_slice(&secp, &s_bytes) + .map_err(de::Error::custom)?, + ); + } + } + } + let relay_url = relay_url.ok_or_else(|| de::Error::missing_field("relay_url"))?; + let ohttp_config = + ohttp_config.ok_or_else(|| de::Error::missing_field("ohttp_config"))?; + let ohttp_proxy = + ohttp_proxy.ok_or_else(|| de::Error::missing_field("ohttp_proxy"))?; + let s = s.ok_or_else(|| de::Error::missing_field("s"))?; + Ok(Enrolled { relay_url, ohttp_config, ohttp_proxy, s }) + } + } + + const FIELDS: &'static [&'static str] = &["relay_url", "ohttp_config", "ohttp_proxy", "s"]; + deserializer.deserialize_struct("Enrolled", FIELDS, EnrolledVisitor) + } +} + impl Enrolled { pub fn extract_req(&self) -> Result<(Request, ohttp::ClientResponse), Error> { let (body, ohttp_ctx) = self.fallback_req_body()?; @@ -174,6 +308,8 @@ impl Enrolled { crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &self.fallback_target(), None) } + pub fn pubkey(&self) -> [u8; 33] { self.s.public_key().serialize() } + pub fn fallback_target(&self) -> String { let pubkey = &self.s.public_key().serialize(); let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);