Skip to content

Commit

Permalink
refactor(dpapi): code refactoring after review: improve error handling
Browse files Browse the repository at this point in the history
and naming;
  • Loading branch information
TheBestTvarynka committed Jan 16, 2025
1 parent ae80799 commit c3cb4d1
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 182 deletions.
134 changes: 82 additions & 52 deletions crates/dpapi/src/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,40 @@ use picky_asn1_x509::enveloped_data::{
KeyEncryptionAlgorithmIdentifier, OtherKeyAttribute, ProtectionDescriptor, RecipientInfo, RecipientInfos,
};
use picky_asn1_x509::oids;
use thiserror::Error;
use uuid::Uuid;

use crate::rpc::{read_buf, read_to_end, read_vec, write_buf, Decode, Encode, EncodeExt};
use crate::sid_utils::{ace_to_bytes, sd_to_bytes};
use crate::utils::{encode_utf16_le, utf16_bytes_to_utf8_string};
use crate::sid::{ace_to_bytes, sd_to_bytes};
use crate::str::{encode_utf16_le, from_utf16_le};
use crate::{DpapiResult, Error};

#[derive(Debug, Error)]
pub enum BlobError {
#[error("unsupported protection descriptor: {0}")]
UnsupportedProtectionDescriptor(String),

#[error("invalid {name}: expected {expected} but got {actual}")]
InvalidOid {
name: &'static str,
expected: String,
actual: String,
},

#[error("invalid {name} version: expected {expected:?} but got {actual:?}")]
InvalidCmsVersion {
name: &'static str,
expected: CmsVersion,
actual: CmsVersion,
},

#[error("bad recipient infos amount: expected {expected} but got {actual}")]
RecipientInfosAmount { expected: usize, actual: usize },

#[error("missing {0} value")]
MissingValue(&'static str),
}

/// Key Identifier
///
/// This contains the key identifier info that can be used by MS-GKDI GetKey to retrieve the group key seed values.
Expand Down Expand Up @@ -89,11 +116,11 @@ impl Decode for KeyIdentifier {
read_buf(&mut reader, &mut magic)?;

if magic != Self::MAGIC {
return Err(Error::InvalidMagicBytes(
"KeyIdentifier",
Self::MAGIC.as_slice(),
magic.to_vec(),
));
return Err(Error::InvalidMagic {
name: "KeyIdentifier",
expected: Self::MAGIC.as_slice(),
actual: magic.to_vec(),
});
}

let flags = reader.read_u32::<LittleEndian>()?;
Expand All @@ -105,31 +132,33 @@ impl Decode for KeyIdentifier {

let key_info_len = reader.read_u32::<LittleEndian>()?;

let domain_len = reader.read_u32::<LittleEndian>()?;
let domain_len = reader.read_u32::<LittleEndian>()?.try_into()?;
if domain_len <= 2 {
return Err(Error::InvalidValue(
"KeyIdentifier domain name length",
format!("expected more than 2 bytes, but got {}", domain_len),
));
return Err(Error::InvalidLength {
name: "KeyIdentifier domain name",
expected: 2,
actual: domain_len,
});
}
let domain_len = domain_len - 2 /* UTF16 null terminator */;

let forest_len = reader.read_u32::<LittleEndian>()?;
let forest_len = reader.read_u32::<LittleEndian>()?.try_into()?;
if forest_len <= 2 {
return Err(Error::InvalidValue(
"KeyIdentifier forest name length",
format!("expected more than 2 bytes, but got {}", forest_len),
));
return Err(Error::InvalidLength {
name: "KeyIdentifier forest name",
expected: 2,
actual: forest_len,
});
}
let forest_len = forest_len - 2 /* UTF16 null terminator */;

let key_info = read_vec(key_info_len.try_into()?, &mut reader)?;

let domain_name = read_vec(domain_len.try_into()?, &mut reader)?;
let domain_name = read_vec(domain_len, &mut reader)?;
// Read UTF16 null terminator.
reader.read_u16::<LittleEndian>()?;

let forest_name = read_vec(forest_len.try_into()?, &mut reader)?;
let forest_name = read_vec(forest_len, &mut reader)?;
// Read UTF16 null terminator.
reader.read_u16::<LittleEndian>()?;

Expand All @@ -141,8 +170,8 @@ impl Decode for KeyIdentifier {
l2,
root_key_identifier,
key_info,
domain_name: utf16_bytes_to_utf8_string(&domain_name)?,
forest_name: utf16_bytes_to_utf8_string(&forest_name)?,
domain_name: from_utf16_le(&domain_name)?,
forest_name: from_utf16_le(&forest_name)?,
})
}
}
Expand Down Expand Up @@ -190,9 +219,9 @@ impl SidProtectionDescriptor {
let general_protection_descriptor: GeneralProtectionDescriptor = picky_asn1_der::from_bytes(data)?;

if general_protection_descriptor.descriptor_type.0 != oids::sid_protection_descriptor() {
return Err(Error::UnsupportedProtectionDescriptor(
Err(BlobError::UnsupportedProtectionDescriptor(
general_protection_descriptor.descriptor_type.0.into(),
));
))?;
}

let ProtectionDescriptor {
Expand All @@ -202,15 +231,15 @@ impl SidProtectionDescriptor {
.descriptors
.0
.first()
.ok_or_else(|| Error::InvalidProtectionDescriptor("missing ASN1 sequence".into()))?
.ok_or(BlobError::MissingValue("protection descriptor"))?
.0
.first()
.ok_or_else(|| Error::InvalidProtectionDescriptor("missing ASN1 sequence".into()))?;
.ok_or(BlobError::MissingValue("protection descriptor"))?;

if descriptor_type.0.as_utf8() != "SID" {
return Err(Error::UnsupportedProtectionDescriptor(
Err(BlobError::UnsupportedProtectionDescriptor(
descriptor_type.0.as_utf8().to_owned(),
));
))?;
}

Ok(Self {
Expand Down Expand Up @@ -293,38 +322,38 @@ impl Decode for DpapiBlob {
let expected_content_type: String = oids::enveloped_data().into();
let actual_content_type: String = content_info.content_type.0.into();

return Err(Error::InvalidValue(
"content info type",
format!("expected {expected_content_type} but got {actual_content_type}"),
));
Err(BlobError::InvalidOid {
name: "blob content type",
expected: expected_content_type,
actual: actual_content_type,
})?;
}

let enveloped_data: EnvelopedData = picky_asn1_der::from_bytes(&content_info.content.0 .0)?;

if enveloped_data.version != CmsVersion::V2 {
return Err(Error::InvalidValue(
"enveloped data CMS version",
format!("expected {:?} but got {:?}", CmsVersion::V2, enveloped_data.version,),
));
Err(BlobError::InvalidCmsVersion {
name: "enveloped data",
expected: CmsVersion::V2,
actual: enveloped_data.version,
})?;
}

if enveloped_data.recipient_infos.0.len() != 1 {
return Err(Error::InvalidValue(
"recipient infos",
format!(
"expected exactly 1 recipient info but got {}",
enveloped_data.recipient_infos.0.len(),
),
));
Err(BlobError::RecipientInfosAmount {
expected: 1,
actual: enveloped_data.recipient_infos.0.len(),
})?;
}

let RecipientInfo::Kek(kek_info) = enveloped_data.recipient_infos.0.first().unwrap();

if kek_info.version != CmsVersion::V4 {
return Err(Error::InvalidValue(
"KEK info CMS version",
format!("expected {:?} but got {:?}", CmsVersion::V4, enveloped_data.version,),
));
Err(BlobError::InvalidCmsVersion {
name: "KEK info",
expected: CmsVersion::V4,
actual: kek_info.version,
})?;
}

let key_identifier = KeyIdentifier::decode(&kek_info.kek_id.key_identifier.0 as &[u8])?;
Expand All @@ -335,19 +364,20 @@ impl Decode for DpapiBlob {
let expected_descriptor: String = oids::protection_descriptor_type().into();
let actual_descriptor: String = (&key_attr_id.0).into();

return Err(Error::InvalidValue(
"KEK recipient info OtherAttribute OID",
format!("expected {expected_descriptor} but got {actual_descriptor}"),
));
Err(BlobError::InvalidOid {
name: "KEK recipient info OtherAttribute OID",
expected: expected_descriptor,
actual: actual_descriptor,
})?;
}

if let Some(encoded_protection_descriptor) = key_attr {
SidProtectionDescriptor::decode_asn1(&encoded_protection_descriptor.0)?
} else {
return Err(Error::MissingValue("KEK recipient info OtherAttribute"));
Err(BlobError::MissingValue("KEK recipient info OtherAttribute"))?
}
} else {
return Err(Error::MissingValue("KEK recipient info protection descriptor"));
Err(BlobError::MissingValue("KEK recipient info protection descriptor"))?
};

let enc_content = if let Some(enc_content) = enveloped_data.encrypted_content_info.encrypted_content.0 {
Expand Down
87 changes: 34 additions & 53 deletions crates/dpapi/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,32 @@ use thiserror::Error;

#[derive(Debug, Error)]
pub enum Error {
#[error("invalid {name} magic bytes")]
InvalidMagic {
name: &'static str,
expected: &'static [u8],
actual: Vec<u8>,
},

#[error("invalid {name} length: expected at least {expected} bytes but got {actual}")]
InvalidLength {
name: &'static str,
expected: usize,
actual: usize,
},

#[error(transparent)]
Gkdi(#[from] crate::gkdi::GkdiError),

#[error(transparent)]
Blob(#[from] crate::blob::BlobError),

#[error(transparent)]
Rpc(#[from] crate::rpc::RpcError),

#[error(transparent)]
Sid(#[from] crate::sid::SidError),

#[error("IO error")]
Io(#[from] std::io::Error),

Expand All @@ -14,68 +40,23 @@ pub enum Error {
#[error("provided buf contains invalid UTF-8 data")]
Utf8(#[from] std::string::FromUtf8Error),

#[error("invalid context result code value: {0}")]
InvalidContextResultCode(u16),

#[error("invalid integer representation value: {0}")]
InvalidIntRepr(u8),

#[error("invalid character representation value: {0}")]
InvalidCharacterRepr(u8),

#[error("invalid floating point representation value: {0}")]
InvalidFloatingPointRepr(u8),

#[error("invalid packet type value: {0}")]
InvalidPacketType(u8),

#[error("invalid packet flags value: {0}")]
InvalidPacketFlags(u8),

#[error("invalid security provider value: {0}")]
InvalidSecurityProvider(u8),

#[error("invalid authentication level value: {0}")]
InvalidAuthenticationLevel(u8),

#[error("invalid fault flags value: {0}")]
InvalidFaultFlags(u8),

#[error("{0:?} PDU is not supported")]
PduNotSupported(crate::rpc::pdu::PacketType),

#[error("invalid fragment (PDU) length: {0}")]
InvalidFragLength(u16),

#[error("invalid {0} magic bytes")]
InvalidMagicBytes(&'static str, &'static [u8], Vec<u8>),

#[error("unsupported protection descriptor: {0}")]
UnsupportedProtectionDescriptor(String),

#[error("invalid protection descriptor: {0}")]
InvalidProtectionDescriptor(std::borrow::Cow<'static, str>),

#[error("invalid {0} value: {0}")]
InvalidValue(&'static str, String),

#[error("missing {0} value")]
MissingValue(&'static str),

#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),

#[error("this error should never occur: {0}")]
Infallible(#[from] std::convert::Infallible),

#[error(transparent)]
Asn1(#[from] picky_asn1_der::Asn1DerError),

#[error(transparent)]
CharSet(#[from] picky_asn1::restricted_string::CharSetError),

#[error(transparent)]
FromUtf16(#[from] std::string::FromUtf16Error),
#[error("{0}")]
FromUtf16(String),
}

impl From<std::string::FromUtf16Error> for Error {
fn from(err: std::string::FromUtf16Error) -> Self {
Self::FromUtf16(err.to_string())
}
}

pub type DpapiResult<T> = Result<T, Error>;
Loading

0 comments on commit c3cb4d1

Please sign in to comment.