diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index d002125d..cdbbf097 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -129,27 +129,21 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey { #[cfg(feature = "send")] pub fn encrypt_message_a( body: Vec, - encapsulation_pair: &HpkeKeyPair, reply_pk: &HpkePublicKey, receiver_pk: &HpkePublicKey, ) -> Result, HpkeError> { let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( - &OpModeS::Auth(( - encapsulation_pair.secret_key().0.clone(), - encapsulation_pair.public_key().0.clone(), - )), + &OpModeS::Base, &receiver_pk.0, INFO_A, &mut OsRng, )?; - let aad = encapsulation_pair.public_key().to_bytes().to_vec(); let mut plaintext = reply_pk.to_bytes().to_vec(); plaintext.extend(body); let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &aad)?; + let ciphertext = encryption_context.seal(plaintext, &[])?; let mut message_a = encapsulated_key.to_bytes().to_vec(); - message_a.extend(&aad); message_a.extend(&ciphertext); Ok(message_a.to_vec()) } @@ -167,21 +161,15 @@ pub fn decrypt_message_a( cursor.read_exact(&mut enc).map_err(|_| HpkeError::PayloadTooShort)?; let enc = EncappedKey::from_bytes(&enc)?; - let mut aad = [0u8; UNCOMPRESSED_PUBLIC_KEY_SIZE]; - cursor.read_exact(&mut aad).map_err(|_| HpkeError::PayloadTooShort)?; - let encapsulation_pk = PublicKey::from_bytes(&aad)?; - - let mut decryption_ctx = - hpke::setup_receiver::( - &OpModeR::Auth(encapsulation_pk.clone()), - &receiver_sk.0, - &enc, - INFO_A, - )?; + let mut decryption_ctx = hpke::setup_receiver::< + ChaCha20Poly1305, + HkdfSha256, + SecpK256HkdfSha256, + >(&OpModeR::Base, &receiver_sk.0, &enc, INFO_A)?; let mut ciphertext = Vec::new(); cursor.read_to_end(&mut ciphertext).map_err(|_| HpkeError::PayloadTooShort)?; - let plaintext = decryption_ctx.open(&ciphertext, &aad)?; + let plaintext = decryption_ctx.open(&ciphertext, &[])?; let reply_pk_bytes = &plaintext[..UNCOMPRESSED_PUBLIC_KEY_SIZE]; let reply_pk = HpkePublicKey(PublicKey::from_bytes(reply_pk_bytes)?); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 6821d940..aa3e8b55 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -320,7 +320,6 @@ impl Sender { let hpke_ctx = HpkeContext::new(rs); let body = encrypt_message_a( body, - &hpke_ctx.encapsulation_pair.clone(), &hpke_ctx.reply_pair.public_key().clone(), &hpke_ctx.receiver.clone(), ) @@ -440,7 +439,6 @@ impl V2GetContext { url.set_path(&subdir); let body = encrypt_message_a( Vec::new(), - &self.hpke_ctx.encapsulation_pair.clone(), &self.hpke_ctx.reply_pair.public_key().clone(), &self.hpke_ctx.receiver.clone(), ) @@ -496,18 +494,13 @@ pub struct PsbtContext { #[cfg(feature = "v2")] struct HpkeContext { receiver: HpkePublicKey, - encapsulation_pair: HpkeKeyPair, reply_pair: HpkeKeyPair, } #[cfg(feature = "v2")] impl HpkeContext { pub fn new(receiver: HpkePublicKey) -> Self { - Self { - receiver, - encapsulation_pair: HpkeKeyPair::gen_keypair(), - reply_pair: HpkeKeyPair::gen_keypair(), - } + Self { receiver, reply_pair: HpkeKeyPair::gen_keypair() } } }