Skip to content

Commit

Permalink
Impl traits for session persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Dec 11, 2023
1 parent 5dca484 commit e386f4f
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude = ["tests"]
send = []
receive = ["rand"]
base64 = ["bitcoin/base64"]
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"]
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp", "serde"]

[dependencies]
bitcoin = { version = "0.30.0", features = ["base64"] }
Expand All @@ -28,6 +28,7 @@ log = { version = "0.4.14"}
ohttp = { version = "0.4.0", optional = true }
bhttp = { version = "0.4.0", optional = true }
rand = { version = "0.8.4", optional = true }
serde = { version = "1.0.186", default-features = false, optional = true }
url = "2.2.2"

[dev-dependencies]
Expand Down
138 changes: 137 additions & 1 deletion payjoin/src/receive/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::collections::HashMap;

use bitcoin::psbt::Psbt;
use bitcoin::{base64, Amount, FeeRate, OutPoint, Script, TxOut};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};

use super::{Error, InternalRequestError, RequestError, SelectionError};
use crate::psbt::PsbtExt;
Expand Down Expand Up @@ -107,14 +109,146 @@ fn subdirectory(pubkey: &bitcoin::secp256k1::PublicKey) -> String {
pubkey_base64
}

#[derive(Debug)]
#[derive(Debug, Clone, PartialEq)]
pub struct Enrolled {
relay_url: url::Url,
ohttp_config: Vec<u8>,
ohttp_proxy: url::Url,
s: bitcoin::secp256k1::KeyPair,
}

impl Serialize for Enrolled {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_struct("Enrolled", 4)?;
state.serialize_field("relay_url", &self.relay_url.to_string())?;
state.serialize_field("ohttp_config", &self.ohttp_config)?;
state.serialize_field("ohttp_proxy", &self.ohttp_proxy.to_string())?;
state.serialize_field("s", &self.s.secret_key().secret_bytes())?;

state.end()
}
}

use std::fmt;
use std::str::FromStr;

use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};

impl<'de> Deserialize<'de> for Enrolled {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
enum Field {
RelayUrl,
OhttpConfig,
OhttpProxy,
S,
}

impl<'de> Deserialize<'de> for Field {
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
where
D: Deserializer<'de>,
{
struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("`relay_url`, `ohttp_config`, `ohttp_proxy`, or `s`")
}

fn visit_str<E>(self, value: &str) -> Result<Field, E>
where
E: de::Error,
{
match value {
"relay_url" => Ok(Field::RelayUrl),
"ohttp_config" => Ok(Field::OhttpConfig),
"ohttp_proxy" => Ok(Field::OhttpProxy),
"s" => Ok(Field::S),
_ => Err(de::Error::unknown_field(value, FIELDS)),
}
}
}

deserializer.deserialize_identifier(FieldVisitor)
}
}

struct EnrolledVisitor;

impl<'de> Visitor<'de> for EnrolledVisitor {
type Value = Enrolled;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct Enrolled")
}

fn visit_map<V>(self, mut map: V) -> Result<Enrolled, V::Error>
where
V: MapAccess<'de>,
{
let mut relay_url = None;
let mut ohttp_config = None;
let mut ohttp_proxy = None;
let mut s = None;
while let Some(key) = map.next_key()? {
match key {
Field::RelayUrl => {
if relay_url.is_some() {
return Err(de::Error::duplicate_field("relay_url"));
}
let url_str: String = map.next_value()?;
relay_url = Some(url::Url::parse(&url_str).map_err(de::Error::custom)?);
}
Field::OhttpConfig => {
if ohttp_config.is_some() {
return Err(de::Error::duplicate_field("ohttp_config"));
}
ohttp_config = Some(map.next_value()?);
}
Field::OhttpProxy => {
if ohttp_proxy.is_some() {
return Err(de::Error::duplicate_field("ohttp_proxy"));
}
let proxy_str: String = map.next_value()?;
ohttp_proxy =
Some(url::Url::parse(&proxy_str).map_err(de::Error::custom)?);
}
Field::S => {
if s.is_some() {
return Err(de::Error::duplicate_field("s"));
}
let s_bytes: Vec<u8> = map.next_value()?;
let secp = bitcoin::secp256k1::Secp256k1::new();
s = Some(
bitcoin::secp256k1::KeyPair::from_seckey_slice(&secp, &s_bytes)
.map_err(de::Error::custom)?,
);
}
}
}
let relay_url = relay_url.ok_or_else(|| de::Error::missing_field("relay_url"))?;
let ohttp_config =
ohttp_config.ok_or_else(|| de::Error::missing_field("ohttp_config"))?;
let ohttp_proxy =
ohttp_proxy.ok_or_else(|| de::Error::missing_field("ohttp_proxy"))?;
let s = s.ok_or_else(|| de::Error::missing_field("s"))?;
Ok(Enrolled { relay_url, ohttp_config, ohttp_proxy, s })
}
}

const FIELDS: &'static [&'static str] = &["relay_url", "ohttp_config", "ohttp_proxy", "s"];
deserializer.deserialize_struct("Enrolled", FIELDS, EnrolledVisitor)
}
}

impl Enrolled {
pub fn extract_req(&self) -> Result<(Request, ohttp::ClientResponse), Error> {
let (body, ohttp_ctx) = self.fallback_req_body()?;
Expand Down Expand Up @@ -174,6 +308,8 @@ impl Enrolled {
crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &self.fallback_target(), None)
}

pub fn pubkey(&self) -> [u8; 33] { self.s.public_key().serialize() }

pub fn fallback_target(&self) -> String {
let pubkey = &self.s.public_key().serialize();
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
Expand Down

0 comments on commit e386f4f

Please sign in to comment.