From e559641cd47101691923c9380446c6b575a11c75 Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Tue, 30 Jan 2024 17:19:54 +0800 Subject: [PATCH] refactor(expr): DRY encrypt/decrypt implementation (#14869) --- src/expr/impl/src/scalar/encrypt.rs | 211 +++++++++++++--------------- 1 file changed, 95 insertions(+), 116 deletions(-) diff --git a/src/expr/impl/src/scalar/encrypt.rs b/src/expr/impl/src/scalar/encrypt.rs index 0603733c47cc1..4754246cacda8 100644 --- a/src/expr/impl/src/scalar/encrypt.rs +++ b/src/expr/impl/src/scalar/encrypt.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::fmt::Debug; +use std::sync::LazyLock; use openssl::error::ErrorStack; use openssl::symm::{Cipher, Crypter, Mode as CipherMode}; @@ -36,116 +37,128 @@ enum Padding { None, } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct CipherConfig { algorithm: Algorithm, mode: Mode, + cipher: Cipher, padding: Padding, crypt_key: Vec, } +/// Because `Cipher` is not `Debug`, we include algorithm, key length and mode manually. +impl std::fmt::Debug for CipherConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CipherConfig") + .field("algorithm", &self.algorithm) + .field("key_len", &self.crypt_key.len()) + .field("mode", &self.mode) + .field("padding", &self.padding) + .finish() + } +} + +static CIPHER_CONFIG_RE: LazyLock = + LazyLock::new(|| Regex::new(r"^(aes|bf)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap()); + impl CipherConfig { fn parse_cipher_config(key: &[u8], input: &str) -> Result { - let re = Regex::new(r"^(aes|bf)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap(); - let (algorithm, mode, padding) = { - if let Some(caps) = re.captures(input) { - let algorithm = match caps.get(1).map(|s| s.as_str()) { - Some("bf") => Algorithm::Blowfish, - Some("aes") => Algorithm::Aes, - algo => { - return Err(ExprError::InvalidParam { - name: "mode", - reason: format!("expect bf or aes for algorithm, but got: {:?}", algo) - .into(), - }) - } - }; + let Some(caps) = CIPHER_CONFIG_RE.captures(input) else { + return Err(ExprError::InvalidParam { + name: "mode", + reason: format!( + "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]", + input + ) + .into(), + }); + }; - let mode = match caps.get(2).map(|m| m.as_str()) { - Some("cbc") | None => Mode::Cbc, - Some("ecb") => Mode::Ecb, // Default to Ecb if not specified - mode => { - return Err(ExprError::InvalidParam { - name: "mode", - reason: format!( - "expect cbc or ecb for mode, but got: {}", - mode.unwrap() - ) - .into(), - }) - } - }; + let algorithm = match caps.get(1).map(|s| s.as_str()) { + Some("bf") => Algorithm::Blowfish, + Some("aes") => Algorithm::Aes, + algo => { + return Err(ExprError::InvalidParam { + name: "mode", + reason: format!("expect bf or aes for algorithm, but got: {:?}", algo).into(), + }) + } + }; - let padding = match caps.get(3).map(|m| m.as_str()) { - Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified - Some("none") => Padding::None, - padding => { - return Err(ExprError::InvalidParam { - name: "mode", - reason: format!( - "expect cbc or ecb for padding, but got: {}", - padding.unwrap() - ) - .into(), - }) - } - }; + let mode = match caps.get(2).map(|m| m.as_str()) { + Some("cbc") | None => Mode::Cbc, // Default to Cbc if not specified + Some("ecb") => Mode::Ecb, + Some(mode) => { + return Err(ExprError::InvalidParam { + name: "mode", + reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(), + }) + } + }; - (algorithm, mode, padding) - } else { + let padding = match caps.get(3).map(|m| m.as_str()) { + Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified + Some("none") => Padding::None, + Some(padding) => { return Err(ExprError::InvalidParam { name: "mode", - reason: format!( - "invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]", - input - ) - .into(), - }); + reason: format!("expect pkcs or none for padding, but got: {}", padding).into(), + }) } }; - if algorithm == Algorithm::Aes && key.len() != 16 && key.len() != 24 && key.len() != 32 { - return Err(ExprError::InvalidParam { - name: "key", - reason: format!("invalid key length: {}, expect 16, 24 or 32", key.len()).into(), - }); - } + let cipher = match (&algorithm, key.len(), &mode) { + (Algorithm::Blowfish, _, Mode::Cbc) => Cipher::bf_cbc(), + (Algorithm::Blowfish, _, Mode::Ecb) => Cipher::bf_ecb(), + (Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(), + (Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(), + (Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(), + (Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(), + (Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(), + (Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(), + (Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => { + return Err(ExprError::InvalidParam { + name: "key", + reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(), + }) + } + }; Ok(CipherConfig { algorithm, mode, + cipher, padding, crypt_key: key.to_vec(), }) } - fn build_cipher(&self) -> Result { - // match config's algorithm, mode, padding to openssl's cipher - match (&self.algorithm, self.crypt_key.len(), &self.mode) { - (Algorithm::Blowfish, _, Mode::Cbc) => Ok(Cipher::bf_cbc()), - (Algorithm::Blowfish, _, Mode::Ecb) => Ok(Cipher::bf_ecb()), - (Algorithm::Aes, 16, Mode::Cbc) => Ok(Cipher::aes_128_cbc()), - (Algorithm::Aes, 16, Mode::Ecb) => Ok(Cipher::aes_128_ecb()), - (Algorithm::Aes, 24, Mode::Cbc) => Ok(Cipher::aes_192_cbc()), - (Algorithm::Aes, 24, Mode::Ecb) => Ok(Cipher::aes_192_ecb()), - (Algorithm::Aes, 32, Mode::Cbc) => Ok(Cipher::aes_256_cbc()), - (Algorithm::Aes, 32, Mode::Ecb) => Ok(Cipher::aes_256_ecb()), - _ => Err(ExprError::InvalidParam { - name: "mode", - reason: format!( - "invalid algorithm {:?} mode: {:?}", - self.algorithm, self.mode - ) - .into(), - }), - } + fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result> { + let operation = match stage { + CryptographyStage::Encrypt => CipherMode::Encrypt, + CryptographyStage::Decrypt => CipherMode::Decrypt, + }; + self.eval_inner(input, operation).map_err(|reason| { + ExprError::Cryptography(Box::new(CryptographyError { stage, reason })) + }) } - fn enable_padding(&self) -> bool { - match self.padding { + fn eval_inner( + &self, + input: &[u8], + operation: CipherMode, + ) -> std::result::Result, ErrorStack> { + let mut decrypter = Crypter::new(self.cipher, operation, self.crypt_key.as_ref(), None)?; + let enable_padding = match self.padding { Padding::Pkcs => true, Padding::None => false, - } + }; + decrypter.pad(enable_padding); + let mut decrypt = vec![0; input.len() + self.cipher.block_size()]; + let count = decrypter.update(input, &mut decrypt)?; + let rest = decrypter.finalize(&mut decrypt[count..])?; + decrypt.truncate(count + rest); + Ok(decrypt.into()) } } @@ -155,24 +168,7 @@ impl CipherConfig { prebuild = "CipherConfig::parse_cipher_config($1, $2)?" )] pub fn decrypt(data: &[u8], config: &CipherConfig) -> Result> { - let report_error = |e: ErrorStack| { - ExprError::Cryptography(Box::new(CryptographyError { - stage: CryptographyStage::Decrypt, - reason: e, - })) - }; - - let cipher = config.build_cipher()?; - let mut decrypter = Crypter::new(cipher, CipherMode::Decrypt, config.crypt_key.as_ref(), None) - .map_err(report_error)?; - decrypter.pad(config.enable_padding()); - let mut decrypt = vec![0; data.len() + cipher.block_size()]; - let count = decrypter.update(data, &mut decrypt).map_err(report_error)?; - let rest = decrypter - .finalize(&mut decrypt[count..]) - .map_err(report_error)?; - decrypt.truncate(count + rest); - Ok(decrypt.into()) + config.eval(data, CryptographyStage::Decrypt) } #[function( @@ -180,24 +176,7 @@ pub fn decrypt(data: &[u8], config: &CipherConfig) -> Result> { prebuild = "CipherConfig::parse_cipher_config($1, $2)?" )] pub fn encrypt(data: &[u8], config: &CipherConfig) -> Result> { - let report_error = |e: ErrorStack| { - ExprError::Cryptography(Box::new(CryptographyError { - stage: CryptographyStage::Encrypt, - reason: e, - })) - }; - - let cipher = config.build_cipher()?; - let mut encryptor = Crypter::new(cipher, CipherMode::Encrypt, config.crypt_key.as_ref(), None) - .map_err(report_error)?; - encryptor.pad(config.enable_padding()); - let mut encrypt = vec![0; data.len() + cipher.block_size()]; - let count = encryptor.update(data, &mut encrypt).map_err(report_error)?; - let rest = encryptor - .finalize(&mut encrypt[count..]) - .map_err(report_error)?; - encrypt.truncate(count + rest); - Ok(encrypt.into()) + config.eval(data, CryptographyStage::Encrypt) } #[cfg(test)]