Skip to content

Commit

Permalink
Remove panics for cleaner error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tanyav2 committed Dec 14, 2020
1 parent 503ae82 commit c970eaf
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 54 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ bincode2 = "2.0.1"
hkdf = "0.9"
hpke = "0.4.1"
rand = "0.7"
serde = { version = "1.0.115", features = [ "derive" ] }
serde_json = "1.0.57"
serde = { version = "1.0", features = [ "derive" ] }
serde_repr = "0.1"

[dev-dependencies]
hex = "0.4"
serde_json = "1.0"
tokio = { version = "0.2", features = [ "full" ] }
116 changes: 64 additions & 52 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use aes_gcm::aead::{generic_array::GenericArray, AeadInPlace, NewAead};
use aes_gcm::Aes128Gcm;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Error, Result};
use bincode2::LengthOption;
use hkdf::Hkdf;
use hpke::{
Expand Down Expand Up @@ -94,12 +94,10 @@ pub struct ObliviousDoHConfigs {
#[doc(hidden)]
impl ObliviousDoHConfigs {
pub fn to_bytes(&self) -> Result<Vec<u8>> {
let mut serialized: Vec<u8> = self
.configs
.iter()
.map(|c| c.to_bytes().unwrap())
.flatten()
.collect();
let mut serialized = Vec::new();
for temp in self.configs.iter().map(|c| c.to_bytes()) {
serialized.extend(temp?);
}
let length = (serialized.len() as u16).to_be_bytes();
serialized.splice(0..0, length.iter().cloned());
Ok(serialized)
Expand Down Expand Up @@ -161,19 +159,21 @@ impl_custom_serde!(ObliviousDoHConfigContents);

impl ObliviousDoHConfigContents {
/// Creates a KeyID for an `ObliviousDoHConfigContents` struct
pub fn identifier(&self) -> Vec<u8> {
let serialized = self.to_bytes().unwrap();
pub fn identifier(&self) -> Result<Vec<u8>> {
let serialized = self
.to_bytes()
.context("could not serialize ObliviousDoHConfigContents")?;
Self::identifier_from_bytes(&serialized)
}

/// Creates a KeyID from a serialized `ObliviousDoHConfigContents`
/// Use this when you already have a serialized `ObliviousDoHConfigContents`
pub fn identifier_from_bytes(key: &[u8]) -> Vec<u8> {
pub fn identifier_from_bytes(key: &[u8]) -> Result<Vec<u8>> {
let key_id_info = LABEL_KEY_ID.to_vec();
let prk = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(None, key);
let mut key_id = [0; KDF_OUTPUT_SIZE];
prk.expand(&key_id_info, &mut key_id).unwrap();
key_id.to_vec()
prk.expand(&key_id_info, &mut key_id).map_err(Error::msg)?;
Ok(key_id.to_vec())
}

/// Asserts that the HPKE suite corresponds to the supported HPKE suite
Expand Down Expand Up @@ -227,18 +227,17 @@ impl ObliviousDoHMessage {
msg_type: ObliviousDoHMessageType,
key: Option<ObliviousDoHConfigContents>,
msg: Vec<u8>,
) -> Self {
let key_id;
if let Some(k) = key {
key_id = k.identifier();
) -> Result<Self> {
let key_id = if let Some(k) = key {
k.identifier()?
} else {
key_id = vec![];
}
Self {
vec![]
};
Ok(Self {
msg_type,
key_id,
encrypted_msg: msg,
}
})
}
}

Expand Down Expand Up @@ -306,36 +305,39 @@ impl ObliviousDoHMessagePlaintext for ObliviousDoHResponseBody {
fn derive_secrets(odoh_secret: &[u8], query: &ObliviousDoHQueryBody) -> Result<(Vec<u8>, Vec<u8>)> {
let key_info = LABEL_KEY.to_vec();
let nonce_info = LABEL_NONCE.to_vec();
let query_bytes = query.to_bytes().unwrap();
let query_bytes = query.to_bytes()?;

let h_key = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(Some(&query_bytes), &odoh_secret);
let mut key = vec![0; AEAD_KEY_SIZE];
h_key.expand(&key_info, &mut key).unwrap();
h_key.expand(&key_info, &mut key).map_err(Error::msg)?;

let h_nonce = Hkdf::<<Kdf as KdfTrait>::HashImpl>::new(Some(&query_bytes), &odoh_secret);
let mut nonce = vec![0; AEAD_NONCE_SIZE];
h_nonce.expand(&nonce_info, &mut nonce).unwrap();
h_nonce
.expand(&nonce_info, &mut nonce)
.map_err(Error::msg)?;

Ok((key, nonce))
}

fn build_query_aad(server_config: &ObliviousDoHConfigContents) -> Vec<u8> {
let key_id = server_config.identifier();
fn build_query_aad(server_config: &ObliviousDoHConfigContents) -> Result<Vec<u8>> {
let key_id = server_config.identifier()?;
let key_id_len = key_id.len();
let key_size_as_u16 = u16::try_from(key_id_len).unwrap().to_be_bytes();
let key_size_as_u16 = u16::try_from(key_id_len)?.to_be_bytes();

let mut aad = vec![ObliviousDoHMessageType::Query as u8];
aad.extend(&key_size_as_u16);
aad.extend(key_id);
aad
Ok(aad)
}

fn encrypt_query_helper(
server_config: &ObliviousDoHConfigContents,
query_body: &[u8],
) -> Result<(Vec<u8>, Vec<u8>)> {
let server_pk = <Kex as KeyExchange>::PublicKey::from_bytes(&server_config.public_key)
.expect("could not deserialize server public key");
.map_err(Error::msg)
.context("could not deserialize server public key")?;

let mut csprng = StdRng::from_entropy();

Expand All @@ -345,15 +347,19 @@ fn encrypt_query_helper(
LABEL_QUERY,
&mut csprng,
)
.expect("invalid server pubkey");
.map_err(Error::msg)
.context("invalid server pubkey")?;

let mut msg_copy = query_body.to_vec();
let query_aad = build_query_aad(server_config);
let query_aad = build_query_aad(server_config)?;
let tag = client_ctx
.seal(&mut msg_copy, &query_aad)
.expect("encryption failed");
.map_err(Error::msg)
.context("encryption failed")?;
let mut odoh_secret = [0; ODOH_SECRET_LEN];
client_ctx.export(LABEL_SECRET, &mut odoh_secret).unwrap();
client_ctx
.export(LABEL_SECRET, &mut odoh_secret)
.map_err(Error::msg)?;

let ciphertext = msg_copy.to_vec();
let result = [
Expand All @@ -372,18 +378,21 @@ async fn decrypt_query_helper(
server_config: &ObliviousDoHConfigContents,
query_ciphertext: Vec<u8>,
) -> Result<(Vec<u8>, Vec<u8>)> {
let aad = build_query_aad(server_config);
let aad = build_query_aad(server_config)?;
let (ciphertext, tag_bytes) = query_ciphertext.split_at(query_ciphertext.len() - AEAD_TAG_SIZE);
let mut ciphertext_copy = ciphertext.to_vec();

let tag = AeadTag::<Aead>::from_bytes(tag_bytes).unwrap();
let tag = AeadTag::<Aead>::from_bytes(tag_bytes).map_err(Error::msg)?;

server_ctx
.open(&mut ciphertext_copy, &aad, &tag)
.expect("invalid ciphertext");
.map_err(Error::msg)
.context("invalid ciphertext")?;

let mut odoh_secret = [0; ODOH_SECRET_LEN];
server_ctx.export(LABEL_SECRET, &mut odoh_secret).unwrap();
server_ctx
.export(LABEL_SECRET, &mut odoh_secret)
.map_err(Error::msg)?;

let plaintext = ciphertext_copy.to_vec();
Ok((plaintext, odoh_secret.to_vec()))
Expand All @@ -394,24 +403,25 @@ async fn decrypt_query_helper(
fn setup_query_context(
key_pair: &ObliviousDoHKeyPair,
encrypted_query_msg: Vec<u8>,
) -> (Vec<u8>, AeadCtxR<Aead, Kdf, Kem>) {
) -> Result<(Vec<u8>, AeadCtxR<Aead, Kdf, Kem>)> {
let server_sk = &key_pair.private_key;

let key_size = <Kex as KeyExchange>::PublicKey::size();
let (enc, ct) = encrypted_query_msg.split_at(key_size);

let encapped_key =
EncappedKey::<Kex>::from_bytes(enc).expect("could not deserialize the encapsulated pubkey");
let encapped_key = EncappedKey::<Kex>::from_bytes(enc)
.map_err(Error::msg)
.context("could not deserialize the encapsulated pubkey")?;

let recv_ctx = hpke::setup_receiver::<Aead, Kdf, Kem>(
&OpModeR::Base,
&server_sk,
&encapped_key,
LABEL_QUERY,
)
.expect("failed to setup receiver");
.map_err(Error::msg)?;

(ct.to_vec(), recv_ctx)
Ok((ct.to_vec(), recv_ctx))
}

/// Encrypts a message `msg` using the symmetric key derived from query
Expand All @@ -422,12 +432,12 @@ async fn encrypt_response_helper(
plaintext_resp_body: &[u8],
query: &ObliviousDoHQueryBody,
) -> Result<Vec<u8>> {
let (key, nonce) = derive_secrets(odoh_secret, query).unwrap();
let (key, nonce) = derive_secrets(odoh_secret, query)?;
let cipher = Aes128Gcm::new(GenericArray::from_slice(&key));
let mut data = plaintext_resp_body.to_owned();
cipher
.encrypt_in_place(GenericArray::from_slice(&nonce), RESPONSE_AAD, &mut data)
.unwrap();
.map_err(Error::msg)?;
Ok(data)
}

Expand All @@ -439,12 +449,12 @@ fn decrypt_response_helper(
encrypted_resp_body: &[u8],
query: &ObliviousDoHQueryBody,
) -> Result<Vec<u8>> {
let (key, nonce) = derive_secrets(odoh_secret, query).unwrap();
let (key, nonce) = derive_secrets(odoh_secret, query)?;
let cipher = Aes128Gcm::new(GenericArray::from_slice(&key));
let mut data = encrypted_resp_body.to_owned();
cipher
.decrypt_in_place(GenericArray::from_slice(&nonce), RESPONSE_AAD, &mut data)
.unwrap();
.map_err(Error::msg)?;
Ok(data)
}

Expand Down Expand Up @@ -473,11 +483,10 @@ pub fn create_query_msg(
Ok((
ObliviousDoHMessage {
msg_type: ObliviousDoHMessageType::Query,
key_id: server_config.identifier(),
key_id: server_config.identifier()?,
encrypted_msg,
}
.to_bytes()
.unwrap(),
.to_bytes()?,
client_secret,
))
}
Expand Down Expand Up @@ -524,15 +533,15 @@ pub async fn parse_received_query(
return Err(anyhow!("ObliviousDoHMessageType is wrong"));
}

let key_id = key_pair.public_key.identifier();
let key_id = key_pair.public_key.identifier()?;
let key_id_recv = de_query.key_id;

if !key_id.eq(&key_id_recv) {
return Err(anyhow!("KeyId of query differs from expected KeyID"));
}

let encrypted_query_msg = de_query.encrypted_msg;
let (ciphertext, mut server_ctx) = setup_query_context(key_pair, encrypted_query_msg);
let (ciphertext, mut server_ctx) = setup_query_context(key_pair, encrypted_query_msg)?;
let (decrypted_msg, server_secret) =
decrypt_query_helper(&mut server_ctx, &key_pair.public_key, ciphertext).await?;
let query = ObliviousDoHQueryBody::from_bytes(&decrypted_msg)?;
Expand Down Expand Up @@ -565,7 +574,7 @@ pub async fn create_response_msg(
.to_bytes()?;
let encrypted_resp = encrypt_response_helper(server_secret, &response_body, query).await?;

ObliviousDoHMessage::new(ObliviousDoHMessageType::Response, None, encrypted_resp).to_bytes()
ObliviousDoHMessage::new(ObliviousDoHMessageType::Response, None, encrypted_resp)?.to_bytes()
}

#[cfg(test)]
Expand Down Expand Up @@ -675,7 +684,10 @@ mod tests {
let odoh_public_key =
get_supported_config(&hex::decode(tv.odohconfigs).unwrap()).unwrap();

assert_eq!(odoh_public_key.identifier(), expected_public_key_id);
assert_eq!(
odoh_public_key.identifier().unwrap(),
expected_public_key_id
);

let key_pair = ObliviousDoHKeyPair {
private_key: secret_key,
Expand Down

0 comments on commit c970eaf

Please sign in to comment.