Skip to content

Commit

Permalink
refactor(expr): DRY encrypt/decrypt implementation (#14869)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangjinwu authored Jan 30, 2024
1 parent 1dbce01 commit e559641
Showing 1 changed file with 95 additions and 116 deletions.
211 changes: 95 additions & 116 deletions src/expr/impl/src/scalar/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::fmt::Debug;
use std::sync::LazyLock;

use openssl::error::ErrorStack;
use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
Expand All @@ -36,116 +37,128 @@ enum Padding {
None,
}

#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct CipherConfig {
algorithm: Algorithm,
mode: Mode,
cipher: Cipher,
padding: Padding,
crypt_key: Vec<u8>,
}

/// Because `Cipher` is not `Debug`, we include algorithm, key length and mode manually.
impl std::fmt::Debug for CipherConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CipherConfig")
.field("algorithm", &self.algorithm)
.field("key_len", &self.crypt_key.len())
.field("mode", &self.mode)
.field("padding", &self.padding)
.finish()
}
}

static CIPHER_CONFIG_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^(aes|bf)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());

impl CipherConfig {
fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
let re = Regex::new(r"^(aes|bf)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap();
let (algorithm, mode, padding) = {
if let Some(caps) = re.captures(input) {
let algorithm = match caps.get(1).map(|s| s.as_str()) {
Some("bf") => Algorithm::Blowfish,
Some("aes") => Algorithm::Aes,
algo => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect bf or aes for algorithm, but got: {:?}", algo)
.into(),
})
}
};
let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
input
)
.into(),
});
};

let mode = match caps.get(2).map(|m| m.as_str()) {
Some("cbc") | None => Mode::Cbc,
Some("ecb") => Mode::Ecb, // Default to Ecb if not specified
mode => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"expect cbc or ecb for mode, but got: {}",
mode.unwrap()
)
.into(),
})
}
};
let algorithm = match caps.get(1).map(|s| s.as_str()) {
Some("bf") => Algorithm::Blowfish,
Some("aes") => Algorithm::Aes,
algo => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect bf or aes for algorithm, but got: {:?}", algo).into(),
})
}
};

let padding = match caps.get(3).map(|m| m.as_str()) {
Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified
Some("none") => Padding::None,
padding => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"expect cbc or ecb for padding, but got: {}",
padding.unwrap()
)
.into(),
})
}
};
let mode = match caps.get(2).map(|m| m.as_str()) {
Some("cbc") | None => Mode::Cbc, // Default to Cbc if not specified
Some("ecb") => Mode::Ecb,
Some(mode) => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
})
}
};

(algorithm, mode, padding)
} else {
let padding = match caps.get(3).map(|m| m.as_str()) {
Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified
Some("none") => Padding::None,
Some(padding) => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
input
)
.into(),
});
reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
})
}
};

if algorithm == Algorithm::Aes && key.len() != 16 && key.len() != 24 && key.len() != 32 {
return Err(ExprError::InvalidParam {
name: "key",
reason: format!("invalid key length: {}, expect 16, 24 or 32", key.len()).into(),
});
}
let cipher = match (&algorithm, key.len(), &mode) {
(Algorithm::Blowfish, _, Mode::Cbc) => Cipher::bf_cbc(),
(Algorithm::Blowfish, _, Mode::Ecb) => Cipher::bf_ecb(),
(Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
(Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
(Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
(Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
(Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
(Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
(Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
return Err(ExprError::InvalidParam {
name: "key",
reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
})
}
};

Ok(CipherConfig {
algorithm,
mode,
cipher,
padding,
crypt_key: key.to_vec(),
})
}

fn build_cipher(&self) -> Result<Cipher> {
// match config's algorithm, mode, padding to openssl's cipher
match (&self.algorithm, self.crypt_key.len(), &self.mode) {
(Algorithm::Blowfish, _, Mode::Cbc) => Ok(Cipher::bf_cbc()),
(Algorithm::Blowfish, _, Mode::Ecb) => Ok(Cipher::bf_ecb()),
(Algorithm::Aes, 16, Mode::Cbc) => Ok(Cipher::aes_128_cbc()),
(Algorithm::Aes, 16, Mode::Ecb) => Ok(Cipher::aes_128_ecb()),
(Algorithm::Aes, 24, Mode::Cbc) => Ok(Cipher::aes_192_cbc()),
(Algorithm::Aes, 24, Mode::Ecb) => Ok(Cipher::aes_192_ecb()),
(Algorithm::Aes, 32, Mode::Cbc) => Ok(Cipher::aes_256_cbc()),
(Algorithm::Aes, 32, Mode::Ecb) => Ok(Cipher::aes_256_ecb()),
_ => Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"invalid algorithm {:?} mode: {:?}",
self.algorithm, self.mode
)
.into(),
}),
}
fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>> {
let operation = match stage {
CryptographyStage::Encrypt => CipherMode::Encrypt,
CryptographyStage::Decrypt => CipherMode::Decrypt,
};
self.eval_inner(input, operation).map_err(|reason| {
ExprError::Cryptography(Box::new(CryptographyError { stage, reason }))
})
}

fn enable_padding(&self) -> bool {
match self.padding {
fn eval_inner(
&self,
input: &[u8],
operation: CipherMode,
) -> std::result::Result<Box<[u8]>, ErrorStack> {
let mut decrypter = Crypter::new(self.cipher, operation, self.crypt_key.as_ref(), None)?;
let enable_padding = match self.padding {
Padding::Pkcs => true,
Padding::None => false,
}
};
decrypter.pad(enable_padding);
let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
let count = decrypter.update(input, &mut decrypt)?;
let rest = decrypter.finalize(&mut decrypt[count..])?;
decrypt.truncate(count + rest);
Ok(decrypt.into())
}
}

Expand All @@ -155,49 +168,15 @@ impl CipherConfig {
prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
)]
pub fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>> {
let report_error = |e: ErrorStack| {
ExprError::Cryptography(Box::new(CryptographyError {
stage: CryptographyStage::Decrypt,
reason: e,
}))
};

let cipher = config.build_cipher()?;
let mut decrypter = Crypter::new(cipher, CipherMode::Decrypt, config.crypt_key.as_ref(), None)
.map_err(report_error)?;
decrypter.pad(config.enable_padding());
let mut decrypt = vec![0; data.len() + cipher.block_size()];
let count = decrypter.update(data, &mut decrypt).map_err(report_error)?;
let rest = decrypter
.finalize(&mut decrypt[count..])
.map_err(report_error)?;
decrypt.truncate(count + rest);
Ok(decrypt.into())
config.eval(data, CryptographyStage::Decrypt)
}

#[function(
"encrypt(bytea, bytea, varchar) -> bytea",
prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
)]
pub fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>> {
let report_error = |e: ErrorStack| {
ExprError::Cryptography(Box::new(CryptographyError {
stage: CryptographyStage::Encrypt,
reason: e,
}))
};

let cipher = config.build_cipher()?;
let mut encryptor = Crypter::new(cipher, CipherMode::Encrypt, config.crypt_key.as_ref(), None)
.map_err(report_error)?;
encryptor.pad(config.enable_padding());
let mut encrypt = vec![0; data.len() + cipher.block_size()];
let count = encryptor.update(data, &mut encrypt).map_err(report_error)?;
let rest = encryptor
.finalize(&mut encrypt[count..])
.map_err(report_error)?;
encrypt.truncate(count + rest);
Ok(encrypt.into())
config.eval(data, CryptographyStage::Encrypt)
}

#[cfg(test)]
Expand Down

0 comments on commit e559641

Please sign in to comment.