diff --git a/crates/dpapi/src/blob.rs b/crates/dpapi/src/blob.rs index cf760cf2..c3106aeb 100644 --- a/crates/dpapi/src/blob.rs +++ b/crates/dpapi/src/blob.rs @@ -14,13 +14,40 @@ use picky_asn1_x509::enveloped_data::{ KeyEncryptionAlgorithmIdentifier, OtherKeyAttribute, ProtectionDescriptor, RecipientInfo, RecipientInfos, }; use picky_asn1_x509::oids; +use thiserror::Error; use uuid::Uuid; use crate::rpc::{read_buf, read_to_end, read_vec, write_buf, Decode, Encode, EncodeExt}; -use crate::sid_utils::{ace_to_bytes, sd_to_bytes}; -use crate::utils::{encode_utf16_le, utf16_bytes_to_utf8_string}; +use crate::sid::{ace_to_bytes, sd_to_bytes}; +use crate::str::{encode_utf16_le, from_utf16_le}; use crate::{DpapiResult, Error}; +#[derive(Debug, Error)] +pub enum BlobError { + #[error("unsupported protection descriptor: {0}")] + UnsupportedProtectionDescriptor(String), + + #[error("invalid {name}: expected {expected} but got {actual}")] + InvalidOid { + name: &'static str, + expected: String, + actual: String, + }, + + #[error("invalid {name} version: expected {expected:?} but got {actual:?}")] + InvalidCmsVersion { + name: &'static str, + expected: CmsVersion, + actual: CmsVersion, + }, + + #[error("bad recipient infos amount: expected {expected} but got {actual}")] + RecipientInfosAmount { expected: usize, actual: usize }, + + #[error("missing {0} value")] + MissingValue(&'static str), +} + /// Key Identifier /// /// This contains the key identifier info that can be used by MS-GKDI GetKey to retrieve the group key seed values. @@ -89,11 +116,11 @@ impl Decode for KeyIdentifier { read_buf(&mut reader, &mut magic)?; if magic != Self::MAGIC { - return Err(Error::InvalidMagicBytes( - "KeyIdentifier", - Self::MAGIC.as_slice(), - magic.to_vec(), - )); + return Err(Error::InvalidMagic { + name: "KeyIdentifier", + expected: Self::MAGIC.as_slice(), + actual: magic.to_vec(), + }); } let flags = reader.read_u32::()?; @@ -105,31 +132,33 @@ impl Decode for KeyIdentifier { let key_info_len = reader.read_u32::()?; - let domain_len = reader.read_u32::()?; + let domain_len = reader.read_u32::()?.try_into()?; if domain_len <= 2 { - return Err(Error::InvalidValue( - "KeyIdentifier domain name length", - format!("expected more than 2 bytes, but got {}", domain_len), - )); + return Err(Error::InvalidLength { + name: "KeyIdentifier domain name", + expected: 2, + actual: domain_len, + }); } let domain_len = domain_len - 2 /* UTF16 null terminator */; - let forest_len = reader.read_u32::()?; + let forest_len = reader.read_u32::()?.try_into()?; if forest_len <= 2 { - return Err(Error::InvalidValue( - "KeyIdentifier forest name length", - format!("expected more than 2 bytes, but got {}", forest_len), - )); + return Err(Error::InvalidLength { + name: "KeyIdentifier forest name", + expected: 2, + actual: forest_len, + }); } let forest_len = forest_len - 2 /* UTF16 null terminator */; let key_info = read_vec(key_info_len.try_into()?, &mut reader)?; - let domain_name = read_vec(domain_len.try_into()?, &mut reader)?; + let domain_name = read_vec(domain_len, &mut reader)?; // Read UTF16 null terminator. reader.read_u16::()?; - let forest_name = read_vec(forest_len.try_into()?, &mut reader)?; + let forest_name = read_vec(forest_len, &mut reader)?; // Read UTF16 null terminator. reader.read_u16::()?; @@ -141,8 +170,8 @@ impl Decode for KeyIdentifier { l2, root_key_identifier, key_info, - domain_name: utf16_bytes_to_utf8_string(&domain_name)?, - forest_name: utf16_bytes_to_utf8_string(&forest_name)?, + domain_name: from_utf16_le(&domain_name)?, + forest_name: from_utf16_le(&forest_name)?, }) } } @@ -190,9 +219,9 @@ impl SidProtectionDescriptor { let general_protection_descriptor: GeneralProtectionDescriptor = picky_asn1_der::from_bytes(data)?; if general_protection_descriptor.descriptor_type.0 != oids::sid_protection_descriptor() { - return Err(Error::UnsupportedProtectionDescriptor( + Err(BlobError::UnsupportedProtectionDescriptor( general_protection_descriptor.descriptor_type.0.into(), - )); + ))?; } let ProtectionDescriptor { @@ -202,15 +231,15 @@ impl SidProtectionDescriptor { .descriptors .0 .first() - .ok_or_else(|| Error::InvalidProtectionDescriptor("missing ASN1 sequence".into()))? + .ok_or(BlobError::MissingValue("protection descriptor"))? .0 .first() - .ok_or_else(|| Error::InvalidProtectionDescriptor("missing ASN1 sequence".into()))?; + .ok_or(BlobError::MissingValue("protection descriptor"))?; if descriptor_type.0.as_utf8() != "SID" { - return Err(Error::UnsupportedProtectionDescriptor( + Err(BlobError::UnsupportedProtectionDescriptor( descriptor_type.0.as_utf8().to_owned(), - )); + ))?; } Ok(Self { @@ -293,38 +322,38 @@ impl Decode for DpapiBlob { let expected_content_type: String = oids::enveloped_data().into(); let actual_content_type: String = content_info.content_type.0.into(); - return Err(Error::InvalidValue( - "content info type", - format!("expected {expected_content_type} but got {actual_content_type}"), - )); + Err(BlobError::InvalidOid { + name: "blob content type", + expected: expected_content_type, + actual: actual_content_type, + })?; } let enveloped_data: EnvelopedData = picky_asn1_der::from_bytes(&content_info.content.0 .0)?; if enveloped_data.version != CmsVersion::V2 { - return Err(Error::InvalidValue( - "enveloped data CMS version", - format!("expected {:?} but got {:?}", CmsVersion::V2, enveloped_data.version,), - )); + Err(BlobError::InvalidCmsVersion { + name: "enveloped data", + expected: CmsVersion::V2, + actual: enveloped_data.version, + })?; } if enveloped_data.recipient_infos.0.len() != 1 { - return Err(Error::InvalidValue( - "recipient infos", - format!( - "expected exactly 1 recipient info but got {}", - enveloped_data.recipient_infos.0.len(), - ), - )); + Err(BlobError::RecipientInfosAmount { + expected: 1, + actual: enveloped_data.recipient_infos.0.len(), + })?; } let RecipientInfo::Kek(kek_info) = enveloped_data.recipient_infos.0.first().unwrap(); if kek_info.version != CmsVersion::V4 { - return Err(Error::InvalidValue( - "KEK info CMS version", - format!("expected {:?} but got {:?}", CmsVersion::V4, enveloped_data.version,), - )); + Err(BlobError::InvalidCmsVersion { + name: "KEK info", + expected: CmsVersion::V4, + actual: kek_info.version, + })?; } let key_identifier = KeyIdentifier::decode(&kek_info.kek_id.key_identifier.0 as &[u8])?; @@ -335,19 +364,20 @@ impl Decode for DpapiBlob { let expected_descriptor: String = oids::protection_descriptor_type().into(); let actual_descriptor: String = (&key_attr_id.0).into(); - return Err(Error::InvalidValue( - "KEK recipient info OtherAttribute OID", - format!("expected {expected_descriptor} but got {actual_descriptor}"), - )); + Err(BlobError::InvalidOid { + name: "KEK recipient info OtherAttribute OID", + expected: expected_descriptor, + actual: actual_descriptor, + })?; } if let Some(encoded_protection_descriptor) = key_attr { SidProtectionDescriptor::decode_asn1(&encoded_protection_descriptor.0)? } else { - return Err(Error::MissingValue("KEK recipient info OtherAttribute")); + Err(BlobError::MissingValue("KEK recipient info OtherAttribute"))? } } else { - return Err(Error::MissingValue("KEK recipient info protection descriptor")); + Err(BlobError::MissingValue("KEK recipient info protection descriptor"))? }; let enc_content = if let Some(enc_content) = enveloped_data.encrypted_content_info.encrypted_content.0 { diff --git a/crates/dpapi/src/error.rs b/crates/dpapi/src/error.rs index aec0e8b9..72d6193c 100644 --- a/crates/dpapi/src/error.rs +++ b/crates/dpapi/src/error.rs @@ -2,6 +2,32 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum Error { + #[error("invalid {name} magic bytes")] + InvalidMagic { + name: &'static str, + expected: &'static [u8], + actual: Vec, + }, + + #[error("invalid {name} length: expected at least {expected} bytes but got {actual}")] + InvalidLength { + name: &'static str, + expected: usize, + actual: usize, + }, + + #[error(transparent)] + Gkdi(#[from] crate::gkdi::GkdiError), + + #[error(transparent)] + Blob(#[from] crate::blob::BlobError), + + #[error(transparent)] + Rpc(#[from] crate::rpc::RpcError), + + #[error(transparent)] + Sid(#[from] crate::sid::SidError), + #[error("IO error")] Io(#[from] std::io::Error), @@ -14,68 +40,23 @@ pub enum Error { #[error("provided buf contains invalid UTF-8 data")] Utf8(#[from] std::string::FromUtf8Error), - #[error("invalid context result code value: {0}")] - InvalidContextResultCode(u16), - - #[error("invalid integer representation value: {0}")] - InvalidIntRepr(u8), - - #[error("invalid character representation value: {0}")] - InvalidCharacterRepr(u8), - - #[error("invalid floating point representation value: {0}")] - InvalidFloatingPointRepr(u8), - - #[error("invalid packet type value: {0}")] - InvalidPacketType(u8), - - #[error("invalid packet flags value: {0}")] - InvalidPacketFlags(u8), - - #[error("invalid security provider value: {0}")] - InvalidSecurityProvider(u8), - - #[error("invalid authentication level value: {0}")] - InvalidAuthenticationLevel(u8), - - #[error("invalid fault flags value: {0}")] - InvalidFaultFlags(u8), - - #[error("{0:?} PDU is not supported")] - PduNotSupported(crate::rpc::pdu::PacketType), - - #[error("invalid fragment (PDU) length: {0}")] - InvalidFragLength(u16), - - #[error("invalid {0} magic bytes")] - InvalidMagicBytes(&'static str, &'static [u8], Vec), - - #[error("unsupported protection descriptor: {0}")] - UnsupportedProtectionDescriptor(String), - - #[error("invalid protection descriptor: {0}")] - InvalidProtectionDescriptor(std::borrow::Cow<'static, str>), - - #[error("invalid {0} value: {0}")] - InvalidValue(&'static str, String), - - #[error("missing {0} value")] - MissingValue(&'static str), - #[error(transparent)] ParseInt(#[from] std::num::ParseIntError), - #[error("this error should never occur: {0}")] - Infallible(#[from] std::convert::Infallible), - #[error(transparent)] Asn1(#[from] picky_asn1_der::Asn1DerError), #[error(transparent)] CharSet(#[from] picky_asn1::restricted_string::CharSetError), - #[error(transparent)] - FromUtf16(#[from] std::string::FromUtf16Error), + #[error("{0}")] + FromUtf16(String), +} + +impl From for Error { + fn from(err: std::string::FromUtf16Error) -> Self { + Self::FromUtf16(err.to_string()) + } } pub type DpapiResult = Result; diff --git a/crates/dpapi/src/gkdi.rs b/crates/dpapi/src/gkdi.rs index 60cb5191..6ed120e9 100644 --- a/crates/dpapi/src/gkdi.rs +++ b/crates/dpapi/src/gkdi.rs @@ -3,12 +3,29 @@ use std::io::{Read, Write}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use num_bigint_dig::BigUint; +use thiserror::Error; use uuid::Uuid; use crate::rpc::{read_buf, read_c_str_utf16_le, read_padding, read_vec, write_buf, write_padding, Decode, Encode}; -use crate::utils::{encode_utf16_le, utf16_bytes_to_utf8_string}; +use crate::str::{encode_utf16_le, from_utf16_le}; use crate::{DpapiResult, Error}; +#[derive(Debug, Error)] +pub enum GkdiError { + #[error("invalid hash algorithm name: {0}")] + InvalidHashName(String), + + #[error("invalid {name} version: expected {expected} but got {actual}")] + InvalidVersion { + name: &'static str, + expected: u32, + actual: u32, + }, + + #[error("invalid elliptic curve id")] + InvalidEllipticCurveId(Vec), +} + /// GetKey RPC Request /// /// This can be used to build the stub data for the GetKey RPC request. @@ -110,7 +127,7 @@ impl fmt::Display for HashAlg { } impl TryFrom<&str> for HashAlg { - type Error = Error; + type Error = GkdiError; fn try_from(data: &str) -> Result { match data { @@ -118,7 +135,7 @@ impl TryFrom<&str> for HashAlg { "SHA256" => Ok(HashAlg::Sha256), "SHA384" => Ok(HashAlg::Sha384), "SHA512" => Ok(HashAlg::Sha512), - _ => Err(Error::InvalidValue("hash algorithm", data.to_owned())), + _ => Err(GkdiError::InvalidHashName(data.to_owned())), } } } @@ -157,11 +174,11 @@ impl Decode for KdfParameters { read_buf(&mut reader, &mut magic_identifier_1)?; if magic_identifier_1 != Self::MAGIC_IDENTIFIER_1 { - return Err(Error::InvalidMagicBytes( - "KdfParameters::MAGIC_IDENTIFIER_1", - Self::MAGIC_IDENTIFIER_1, - magic_identifier_1.to_vec(), - )); + return Err(Error::InvalidMagic { + name: "KdfParameters::MAGIC_IDENTIFIER_1", + expected: Self::MAGIC_IDENTIFIER_1, + actual: magic_identifier_1.to_vec(), + }); } let hash_name_len: usize = reader.read_u32::()?.try_into()?; @@ -170,19 +187,20 @@ impl Decode for KdfParameters { read_buf(&mut reader, &mut magic_identifier_2)?; if magic_identifier_2 != Self::MAGIC_IDENTIFIER_2 { - return Err(Error::InvalidMagicBytes( - "KdfParameters::MAGIC_IDENTIFIER_1", - Self::MAGIC_IDENTIFIER_2, - magic_identifier_2.to_vec(), - )); + return Err(Error::InvalidMagic { + name: "KdfParameters::MAGIC_IDENTIFIER_1", + expected: Self::MAGIC_IDENTIFIER_2, + actual: magic_identifier_2.to_vec(), + }); } // The smallest possible hash algorithm name is "SHA1\0", 10 bytes long in UTF-16 encoding. if hash_name_len < 10 { - return Err(Error::InvalidValue( - "KdfParameters hash algorithm name length", - format!("expected at least 10 bytes but got {:?}", hash_name_len), - )); + Err(Error::InvalidLength { + name: "KdfParameters hash id", + expected: 10, + actual: hash_name_len, + })?; } let buf = read_vec(hash_name_len - 2 /* UTF-16 null terminator char */, &mut reader)?; @@ -190,7 +208,7 @@ impl Decode for KdfParameters { reader.read_u16::()?; Ok(Self { - hash_alg: utf16_bytes_to_utf8_string(&buf)?.as_str().try_into()?, + hash_alg: from_utf16_le(&buf)?.as_str().try_into()?, }) } } @@ -254,7 +272,11 @@ impl Decode for FfcdhParameters { read_buf(&mut reader, &mut magic)?; if magic != Self::MAGIC { - return Err(Error::InvalidMagicBytes("FfcdhParameters", Self::MAGIC, magic.to_vec())); + return Err(Error::InvalidMagic { + name: "FfcdhParameters", + expected: Self::MAGIC, + actual: magic.to_vec(), + }); } let key_length = reader.read_u32::()?; @@ -329,7 +351,11 @@ impl Decode for FfcdhKey { read_buf(&mut reader, &mut magic)?; if magic != FfcdhKey::MAGIC { - return Err(Error::InvalidMagicBytes("FfcdhKey", Self::MAGIC, magic.to_vec())); + return Err(Error::InvalidMagic { + name: "FfcdhKey", + expected: Self::MAGIC, + actual: magic.to_vec(), + }); } let key_length = reader.read_u32::()?; @@ -365,14 +391,14 @@ impl From for &[u8] { } impl TryFrom<&[u8]> for EllipticCurve { - type Error = Error; + type Error = GkdiError; fn try_from(value: &[u8]) -> Result { match value { b"ECK1" => Ok(EllipticCurve::P256), b"ECK3" => Ok(EllipticCurve::P384), b"ECK5" => Ok(EllipticCurve::P521), - _ => Err(Error::InvalidValue("elliptic curve id", format!("{:?}", value))), + _ => Err(GkdiError::InvalidEllipticCurveId(value.to_vec())), } } } @@ -535,21 +561,22 @@ impl Decode for GroupKeyEnvelope { let version = reader.read_u32::()?; if version != Self::VERSION { - return Err(Error::InvalidValue( - "GroupKeyEnvelope version", - format!("expected {} but got {}", Self::VERSION, version), - )); + Err(GkdiError::InvalidVersion { + name: "GroupKeyEnvelope", + expected: Self::VERSION, + actual: version, + })?; } let mut magic = [0; 4]; read_buf(&mut reader, &mut magic)?; if magic != Self::MAGIC { - return Err(Error::InvalidMagicBytes( - "GroupKeyEnvelope", - Self::MAGIC, - magic.to_vec(), - )); + return Err(Error::InvalidMagic { + name: "GroupKeyEnvelope", + expected: Self::MAGIC, + actual: magic.to_vec(), + }); } let flags = reader.read_u32::()?; diff --git a/crates/dpapi/src/lib.rs b/crates/dpapi/src/lib.rs index f35697d8..692f2fec 100644 --- a/crates/dpapi/src/lib.rs +++ b/crates/dpapi/src/lib.rs @@ -1,4 +1,3 @@ -// #![warn(missing_docs)] #![doc = include_str!("../README.md")] #![allow(dead_code)] @@ -6,7 +5,7 @@ pub mod blob; pub mod error; pub mod gkdi; pub mod rpc; -pub(crate) mod sid_utils; -pub(crate) mod utils; +pub(crate) mod sid; +pub(crate) mod str; pub use error::*; diff --git a/crates/dpapi/src/rpc/bind.rs b/crates/dpapi/src/rpc/bind.rs index aa4cd3de..ce49da26 100644 --- a/crates/dpapi/src/rpc/bind.rs +++ b/crates/dpapi/src/rpc/bind.rs @@ -1,10 +1,17 @@ use std::io::{Read, Write}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use thiserror::Error; use uuid::Uuid; use crate::rpc::{read_padding, read_vec, write_buf, write_padding, Decode, Encode}; -use crate::{DpapiResult, Error}; +use crate::DpapiResult; + +#[derive(Debug, Error)] +pub enum BindError { + #[error("invalid context result code value: {0}")] + InvalidContextResultCode(u16), +} /// [BindTimeFeatureNegotiationBitmask](https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-rpce/cef529cc-77b5-4794-85dc-91e1467e80f0) /// @@ -113,15 +120,15 @@ impl ContextResultCode { } impl TryFrom for ContextResultCode { - type Error = Error; + type Error = BindError; - fn try_from(v: u16) -> DpapiResult { + fn try_from(v: u16) -> Result { match v { 0 => Ok(Self::Acceptance), 1 => Ok(Self::UserRejection), 2 => Ok(Self::ProviderRejection), 3 => Ok(Self::NegotiateAck), - v => Err(Error::InvalidContextResultCode(v)), + v => Err(BindError::InvalidContextResultCode(v)), } } } diff --git a/crates/dpapi/src/rpc/mod.rs b/crates/dpapi/src/rpc/mod.rs index 63d470fd..f03dcd78 100644 --- a/crates/dpapi/src/rpc/mod.rs +++ b/crates/dpapi/src/rpc/mod.rs @@ -4,10 +4,34 @@ pub mod request; use std::io::{ErrorKind as IoErrorKind, Read, Write}; +use thiserror::Error; use uuid::Uuid; +use self::bind::BindError; +use self::pdu::PduError; use crate::{DpapiResult, Error}; +#[derive(Debug, Error)] +pub enum RpcError { + #[error(transparent)] + Bind(BindError), + + #[error(transparent)] + Pdu(PduError), +} + +impl From for Error { + fn from(err: PduError) -> Self { + Error::from(RpcError::Pdu(err)) + } +} + +impl From for Error { + fn from(err: BindError) -> Self { + Error::from(RpcError::Bind(err)) + } +} + pub trait Encode { fn encode(&self, writer: impl Write) -> DpapiResult<()>; } @@ -107,13 +131,14 @@ pub fn read_vec(len: usize, reader: impl Read) -> DpapiResult> { pub fn read_c_str_utf16_le(len: usize, mut reader: impl Read) -> DpapiResult { use byteorder::{LittleEndian, ReadBytesExt}; - use crate::utils::utf16_bytes_to_utf8_string; + use crate::str::from_utf16_le; if len < 2 { - return Err(Error::InvalidValue( - "invalid UTF-17 string length", - format!("expected more than 2 bytes, but got {}", len), - )); + return Err(Error::InvalidLength { + name: "UTF-16 string", + expected: 2, + actual: len, + }); } let buf = read_vec(len - 2 /* UTF16 null terminator */, &mut reader)?; @@ -121,5 +146,5 @@ pub fn read_c_str_utf16_le(len: usize, mut reader: impl Read) -> DpapiResult()?; - utf16_bytes_to_utf8_string(&buf) + from_utf16_le(&buf) } diff --git a/crates/dpapi/src/rpc/pdu.rs b/crates/dpapi/src/rpc/pdu.rs index b88ff771..918a7ad7 100644 --- a/crates/dpapi/src/rpc/pdu.rs +++ b/crates/dpapi/src/rpc/pdu.rs @@ -3,11 +3,45 @@ use std::io::{Read, Write}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use num_derive::FromPrimitive; use num_traits::FromPrimitive; +use thiserror::Error; use super::{read_to_end, read_vec, write_buf, Decode, Encode}; use crate::rpc::bind::{AlterContext, AlterContextResponse, Bind, BindAck, BindNak}; use crate::rpc::request::{Request, Response}; -use crate::{DpapiResult, Error}; +use crate::DpapiResult; + +#[derive(Error, Debug)] +pub enum PduError { + #[error("invalid integer representation value: {0}")] + InvalidIntRepr(u8), + + #[error("invalid character representation value: {0}")] + InvalidCharacterRepr(u8), + + #[error("invalid floating point representation value: {0}")] + InvalidFloatingPointRepr(u8), + + #[error("invalid packet type value: {0}")] + InvalidPacketType(u8), + + #[error("invalid packet flags value: {0}")] + InvalidPacketFlags(u8), + + #[error("invalid security provider value: {0}")] + InvalidSecurityProvider(u8), + + #[error("invalid authentication level value: {0}")] + InvalidAuthenticationLevel(u8), + + #[error("invalid fault flags value: {0}")] + InvalidFaultFlags(u8), + + #[error("{0:?} PDU is not supported")] + PduNotSupported(crate::rpc::pdu::PacketType), + + #[error("invalid fragment (PDU) length: {0}")] + InvalidFragLength(u16), +} #[derive(Debug, Clone, Copy, PartialEq, Eq, Default, FromPrimitive)] #[repr(u8)] @@ -129,11 +163,11 @@ impl Decode for DataRepr { let data_representation = Self { byte_order: IntRepr::from_u8(integer_representation) - .ok_or_else(|| Error::InvalidIntRepr(integer_representation))?, + .ok_or(PduError::InvalidIntRepr(integer_representation))?, character: CharacterRepr::from_u8(character_representation) - .ok_or_else(|| Error::InvalidCharacterRepr(character_representation))?, + .ok_or(PduError::InvalidCharacterRepr(character_representation))?, floating_point: FloatingPointRepr::from_u8(floating_representation) - .ok_or_else(|| Error::InvalidFloatingPointRepr(floating_representation))?, + .ok_or(PduError::InvalidFloatingPointRepr(floating_representation))?, }; // Padding. @@ -177,11 +211,11 @@ impl Decode for PduHeader { version_minor: reader.read_u8()?, packet_type: { let packet_type = reader.read_u8()?; - PacketType::from_u8(packet_type).ok_or_else(|| Error::InvalidPacketType(packet_type))? + PacketType::from_u8(packet_type).ok_or(PduError::InvalidPacketType(packet_type))? }, packet_flags: { let packet_flags = reader.read_u8()?; - PacketFlags::from_bits(packet_flags).ok_or_else(|| Error::InvalidPacketFlags(packet_flags))? + PacketFlags::from_bits(packet_flags).ok_or(PduError::InvalidPacketFlags(packet_flags))? }, data_rep: DataRepr::decode(&mut reader)?, frag_len: reader.read_u16::()?, @@ -256,9 +290,9 @@ impl Decode for SecurityTrailer { Ok(Self { security_type: SecurityProvider::from_u8(security_provider) - .ok_or_else(|| Error::InvalidSecurityProvider(security_provider))?, + .ok_or(PduError::InvalidSecurityProvider(security_provider))?, level: AuthenticationLevel::from_u8(authentication_level) - .ok_or_else(|| Error::InvalidAuthenticationLevel(authentication_level))?, + .ok_or(PduError::InvalidAuthenticationLevel(authentication_level))?, pad_length: reader.read_u8()?, context_id: { // Skip Auth-Rsrvd. @@ -313,7 +347,7 @@ impl Decode for Fault { cancel_count: reader.read_u8()?, flags: { let fault_flags = reader.read_u8()?; - FaultFlags::from_bits(fault_flags).ok_or_else(|| Error::InvalidFaultFlags(fault_flags))? + FaultFlags::from_bits(fault_flags).ok_or(PduError::InvalidFaultFlags(fault_flags))? }, status: reader.read_u32::()?, stub_data: read_to_end(reader)?, @@ -348,7 +382,7 @@ impl PduData { PacketType::Request => Ok(PduData::Request(Request::decode(pdu_header, buf.as_slice())?)), PacketType::Response => Ok(PduData::Response(Response::decode(buf.as_slice())?)), PacketType::Fault => Ok(PduData::Fault(Fault::decode(buf.as_slice())?)), - packet_type => Err(Error::PduNotSupported(packet_type)), + packet_type => Err(PduError::PduNotSupported(packet_type))?, } } } @@ -396,7 +430,7 @@ impl Decode for Pdu { .checked_sub( header.auth_len + 8 /* security trailer header */ + 16, /* PDU header len */ ) - .ok_or_else(|| Error::InvalidFragLength(header.frag_len))? + .ok_or(PduError::InvalidFragLength(header.frag_len))? .into(), &mut reader, )?; diff --git a/crates/dpapi/src/sid_utils.rs b/crates/dpapi/src/sid.rs similarity index 93% rename from crates/dpapi/src/sid_utils.rs rename to crates/dpapi/src/sid.rs index 2fccd242..9e8f2b9d 100644 --- a/crates/dpapi/src/sid_utils.rs +++ b/crates/dpapi/src/sid.rs @@ -1,21 +1,28 @@ use std::sync::LazyLock; use regex::Regex; +use thiserror::Error; -use crate::{DpapiResult, Error}; +use crate::DpapiResult; + +#[derive(Debug, Error)] +pub enum SidError { + #[error("invalid sid value: {0}")] + InvalidSid(String), +} static SID_PATTERN: LazyLock = LazyLock::new(|| Regex::new(r"^S-(\d)-(\d+)(?:-\d+){1,15}$").expect("valid SID regex")); pub fn sid_to_bytes(sid: &str) -> DpapiResult> { if !SID_PATTERN.is_match(sid) { - return Err(Error::InvalidValue("SID", sid.to_owned())); + Err(SidError::InvalidSid(sid.to_owned()))?; } let parts = sid.split('-').collect::>(); if parts.len() < 3 { - return Err(Error::InvalidValue("SID", sid.to_owned())); + Err(SidError::InvalidSid(sid.to_owned()))?; } let revision = parts[1].parse::()?; diff --git a/crates/dpapi/src/str.rs b/crates/dpapi/src/str.rs new file mode 100644 index 00000000..0343b0cb --- /dev/null +++ b/crates/dpapi/src/str.rs @@ -0,0 +1,32 @@ +use crate::{DpapiResult, Error}; + +/// Decodes a UTF-16–encoded byte slice into a [String]. +/// +/// The input `data` slice should has the size multiple of two (`data.len() % 2 == 0`). +/// Otherwise, the function will return an error. +/// +/// *Note*: this function does not expect a NULL-char at the end of the byte slice. +pub fn from_utf16_le(data: &[u8]) -> DpapiResult { + if data.len() % 2 != 0 { + return Err(Error::FromUtf16( + "invalid UTF-16: byte slice should has the size multiple of two".into(), + )); + } + + Ok(String::from_utf16( + &data + .chunks(2) + .map(|c| u16::from_le_bytes(c.try_into().unwrap())) + .collect::>(), + )?) +} + +/// Encodes str into a UTF-16 encoded byte array. +/// +/// *Note*: this function automatically appends a NULL-char. +pub fn encode_utf16_le(data: &str) -> Vec { + data.encode_utf16() + .chain(std::iter::once(0)) + .flat_map(|v| v.to_le_bytes()) + .collect::>() +} diff --git a/crates/dpapi/src/utils.rs b/crates/dpapi/src/utils.rs deleted file mode 100644 index 75ab8404..00000000 --- a/crates/dpapi/src/utils.rs +++ /dev/null @@ -1,19 +0,0 @@ -use crate::DpapiResult; - -pub fn utf16_bytes_to_utf8_string(data: &[u8]) -> DpapiResult { - debug_assert_eq!(data.len() % 2, 0); - - Ok(String::from_utf16( - &data - .chunks(2) - .map(|c| u16::from_le_bytes(c.try_into().unwrap())) - .collect::>(), - )?) -} - -pub fn encode_utf16_le(data: &str) -> Vec { - data.encode_utf16() - .chain(std::iter::once(0)) - .flat_map(|v| v.to_le_bytes()) - .collect::>() -}