Skip to content

Commit

Permalink
feat: support sql func encrypt/decrypt (#14717)
Browse files Browse the repository at this point in the history
Co-authored-by: xiangjinwu <[email protected]>
  • Loading branch information
tabVersion and xiangjinwu authored Feb 1, 2024
1 parent d224230 commit 344cf99
Show file tree
Hide file tree
Showing 15 changed files with 500 additions and 2 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions ci/scripts/regress-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ dpkg-reconfigure --frontend=noninteractive locales
# All the above is required because otherwise psql would throw some warning
# that goes into the output file and thus diverges from the expected output file.
export PGPASSWORD='postgres';

# Load extensions. This must be done only once per database, so it is not part of the test runner.
psql -h db -p 5432 -d postgres -U postgres \
-c 'create extension pgcrypto;' \
-c 'create extension hstore;' \
-c 'create extension tablefunc;'

RUST_BACKTRACE=1 target/debug/risingwave_regress_test --host db \
-p 5432 \
-u postgres \
Expand Down
2 changes: 2 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ message ExprNode {
PGWIRE_RECV = 321;
CONVERT_FROM = 322;
CONVERT_TO = 323;
DECRYPT = 324;
ENCRYPT = 325;

// Unary operators
NEG = 401;
Expand Down
1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ itertools = "0.12"
linkme = { version = "0.3", features = ["used_linker"] }
moka = { version = "0.12", features = ["future"] }
num-traits = "0.2"
openssl = { version = "0.10", features = ["vendored"] }
parse-display = "0.8"
paste = "1"
risingwave_common = { workspace = true }
Expand Down
19 changes: 18 additions & 1 deletion src/expr/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Display;
use std::fmt::{Debug, Display};

use risingwave_common::array::{ArrayError, ArrayRef};
use risingwave_common::error::{ErrorCode, RwError};
Expand Down Expand Up @@ -117,6 +117,23 @@ pub enum ExprError {

#[error("invalid state: {0}")]
InvalidState(String),

#[error("error in cryptography: {0}")]
Cryptography(Box<CryptographyError>),
}

/// Direction of the cryptographic operation in which a failure occurred.
///
/// The variant name is what callers see when the stage is formatted with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CryptographyStage {
    /// The failure happened while encrypting.
    Encrypt,
    /// The failure happened while decrypting.
    Decrypt,
}

/// Error raised when an encrypt or decrypt operation fails, tagged with the stage it
/// occurred in.
///
/// The display message contains the `Debug` form of `stage` followed by the display form
/// of `reason`.
// NOTE(review): `reason` is both interpolated into the message and marked `#[source]`, so
// reporters that walk the source chain may print the OpenSSL message twice — confirm
// whether that is intended.
#[derive(Debug, Error)]
#[error("{stage:?} stage, reason: {reason}")]
pub struct CryptographyError {
/// Whether the failing operation was an encryption or a decryption.
pub stage: CryptographyStage,
/// The OpenSSL error stack describing the failure; also exposed as the error source.
#[source]
pub reason: openssl::error::ErrorStack,
}

static_assertions::const_assert_eq!(std::mem::size_of::<ExprError>(), 40);
Expand Down
2 changes: 1 addition & 1 deletion src/expr/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ pub mod sig;
pub mod table_function;
pub mod window_function;

pub use error::{ContextUnavailable, ExprError, Result};
pub use error::{ContextUnavailable, CryptographyError, CryptographyStage, ExprError, Result};
pub use risingwave_common::{bail, ensure};
pub use risingwave_expr_macro::*;
1 change: 1 addition & 0 deletions src/expr/impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jsonbb = "0.1.2"
linkme = { version = "0.3", features = ["used_linker"] }
md5 = "0.7"
num-traits = "0.2"
openssl = { version = "0.10", features = ["vendored"] }
regex = "1"
risingwave_common = { workspace = true }
risingwave_expr = { workspace = true }
Expand Down
249 changes: 249 additions & 0 deletions src/expr/impl/src/scalar/encrypt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;
use std::sync::LazyLock;

use openssl::error::ErrorStack;
use openssl::symm::{Cipher, Crypter, Mode as CipherMode};
use regex::Regex;
use risingwave_expr::{function, CryptographyError, CryptographyStage, ExprError, Result};

/// Cipher algorithm selected by the first component of the cipher type string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Algorithm {
    /// AES, spelled `aes` in the type string.
    Aes,
}

/// Block cipher mode selected by the optional `-mode` component of the cipher type string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Cipher block chaining, spelled `cbc` in the type string.
    Cbc,
    /// Electronic codebook, spelled `ecb` in the type string.
    Ecb,
}
/// Padding scheme selected by the optional `/pad:` component of the cipher type string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Padding {
    /// Padding enabled, spelled `pkcs` in the type string.
    Pkcs,
    /// Padding disabled, spelled `none` in the type string.
    None,
}

/// Cipher settings built by `CipherConfig::parse_cipher_config` from a key and a cipher
/// type string, and then used to encrypt or decrypt data.
#[derive(Clone)]
pub struct CipherConfig {
/// Algorithm parsed from the type string.
algorithm: Algorithm,
/// Block mode parsed from the type string.
mode: Mode,
/// OpenSSL cipher chosen from the algorithm, the key length and the mode.
cipher: Cipher,
/// Whether padding is applied.
padding: Padding,
/// Owned copy of the raw key bytes.
crypt_key: Vec<u8>,
}

/// `Cipher` has no `Debug` implementation, so this is written by hand: it reports the
/// algorithm, the key length (not the key bytes), the mode and the padding.
impl std::fmt::Debug for CipherConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_len = self.crypt_key.len();

        let mut builder = f.debug_struct("CipherConfig");
        builder.field("algorithm", &self.algorithm);
        builder.field("key_len", &key_len);
        builder.field("mode", &self.mode);
        builder.field("padding", &self.padding);
        builder.finish()
    }
}

/// Lazily compiled pattern for the cipher type string `algorithm[-mode][/pad:padding]`.
///
/// Capture group 1 is the algorithm (`aes`), group 2 the optional mode (`cbc` or `ecb`)
/// and group 3 the optional padding (`pkcs` or `none`). The pattern is anchored at both
/// ends, so the whole string must match.
static CIPHER_CONFIG_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"^(aes)(?:-(cbc|ecb))?(?:/pad:(pkcs|none))?$").unwrap());

impl CipherConfig {
fn parse_cipher_config(key: &[u8], input: &str) -> Result<CipherConfig> {
let Some(caps) = CIPHER_CONFIG_RE.captures(input) else {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!(
"invalid mode: {}, expect pattern algorithm[-mode][/pad:padding]",
input
)
.into(),
});
};

let algorithm = match caps.get(1).map(|s| s.as_str()) {
Some("aes") => Algorithm::Aes,
algo => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect aes for algorithm, but got: {:?}", algo).into(),
})
}
};

let mode = match caps.get(2).map(|m| m.as_str()) {
Some("cbc") | None => Mode::Cbc, // Default to Cbc if not specified
Some("ecb") => Mode::Ecb,
Some(mode) => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect cbc or ecb for mode, but got: {}", mode).into(),
})
}
};

let padding = match caps.get(3).map(|m| m.as_str()) {
Some("pkcs") | None => Padding::Pkcs, // Default to Pkcs if not specified
Some("none") => Padding::None,
Some(padding) => {
return Err(ExprError::InvalidParam {
name: "mode",
reason: format!("expect pkcs or none for padding, but got: {}", padding).into(),
})
}
};

let cipher = match (&algorithm, key.len(), &mode) {
(Algorithm::Aes, 16, Mode::Cbc) => Cipher::aes_128_cbc(),
(Algorithm::Aes, 16, Mode::Ecb) => Cipher::aes_128_ecb(),
(Algorithm::Aes, 24, Mode::Cbc) => Cipher::aes_192_cbc(),
(Algorithm::Aes, 24, Mode::Ecb) => Cipher::aes_192_ecb(),
(Algorithm::Aes, 32, Mode::Cbc) => Cipher::aes_256_cbc(),
(Algorithm::Aes, 32, Mode::Ecb) => Cipher::aes_256_ecb(),
(Algorithm::Aes, n, Mode::Cbc | Mode::Ecb) => {
return Err(ExprError::InvalidParam {
name: "key",
reason: format!("invalid key length: {}, expect 16, 24 or 32", n).into(),
})
}
};

Ok(CipherConfig {
algorithm,
mode,
cipher,
padding,
crypt_key: key.to_vec(),
})
}

fn eval(&self, input: &[u8], stage: CryptographyStage) -> Result<Box<[u8]>> {
let operation = match stage {
CryptographyStage::Encrypt => CipherMode::Encrypt,
CryptographyStage::Decrypt => CipherMode::Decrypt,
};
self.eval_inner(input, operation).map_err(|reason| {
ExprError::Cryptography(Box::new(CryptographyError { stage, reason }))
})
}

fn eval_inner(
&self,
input: &[u8],
operation: CipherMode,
) -> std::result::Result<Box<[u8]>, ErrorStack> {
let mut decrypter = Crypter::new(self.cipher, operation, self.crypt_key.as_ref(), None)?;
let enable_padding = match self.padding {
Padding::Pkcs => true,
Padding::None => false,
};
decrypter.pad(enable_padding);
let mut decrypt = vec![0; input.len() + self.cipher.block_size()];
let count = decrypter.update(input, &mut decrypt)?;
let rest = decrypter.finalize(&mut decrypt[count..])?;
decrypt.truncate(count + rest);
Ok(decrypt.into())
}
}

/// Decrypts `data` using the cipher settings in `config`.
///
/// `config` is prebuilt by `CipherConfig::parse_cipher_config` from the second and third
/// SQL arguments (the key and the cipher type string).
///
/// from [pg doc](https://www.postgresql.org/docs/current/pgcrypto.html#PGCRYPTO-RAW-ENC-FUNCS)
#[function(
"decrypt(bytea, bytea, varchar) -> bytea",
prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
)]
pub fn decrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>> {
config.eval(data, CryptographyStage::Decrypt)
}

/// Encrypts `data` using the cipher settings in `config`.
///
/// `config` is prebuilt by `CipherConfig::parse_cipher_config` from the second and third
/// SQL arguments (the key and the cipher type string).
///
/// from [pg doc](https://www.postgresql.org/docs/current/pgcrypto.html#PGCRYPTO-RAW-ENC-FUNCS)
#[function(
"encrypt(bytea, bytea, varchar) -> bytea",
prebuild = "CipherConfig::parse_cipher_config($1, $2)?"
)]
pub fn encrypt(data: &[u8], config: &CipherConfig) -> Result<Box<[u8]>> {
config.eval(data, CryptographyStage::Encrypt)
}

#[cfg(test)]
mod test {
    use super::*;

    /// Round trip with the default settings (`aes` alone selects CBC with PKCS padding).
    #[test]
    fn test_decrypt() {
        let data = b"hello world";
        let mode = "aes";

        let config = CipherConfig::parse_cipher_config(
            b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F" as &[u8],
            mode,
        )
        .unwrap();
        let encrypted = encrypt(data, &config).unwrap();

        let decrypted = decrypt(&encrypted, &config).unwrap();
        assert_eq!(decrypted, (*data).into());
    }

    /// Known-answer and round-trip check for AES-128 in ECB mode without padding.
    #[test]
    fn encrypt_testcase() {
        let encrypt_wrapper = |data: &[u8], key: &[u8], mode: &str| -> Result<Box<[u8]>> {
            let config = CipherConfig::parse_cipher_config(key, mode)?;
            encrypt(data, &config)
        };
        let decrypt_wrapper = |data: &[u8], key: &[u8], mode: &str| -> Result<Box<[u8]>> {
            let config = CipherConfig::parse_cipher_config(key, mode)?;
            decrypt(data, &config)
        };
        let key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";

        let encrypted = encrypt_wrapper(
            b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff",
            key,
            "aes-ecb/pad:none",
        )
        .unwrap();

        // Key and plaintext are the AES-128 example from FIPS-197 Appendix C.1, so the
        // ciphertext is pinned as well: a round trip alone would also pass if `encrypt`
        // returned its input unchanged.
        assert_eq!(
            encrypted,
            (*b"\x69\xc4\xe0\xd8\x6a\x7b\x04\x30\xd8\xcd\xb7\x80\x70\xb4\xc5\x5a").into()
        );

        let decrypted = decrypt_wrapper(&encrypted, key, "aes-ecb/pad:none").unwrap();
        assert_eq!(
            decrypted,
            (*b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff").into()
        )
    }

    /// The type string is parsed into algorithm, mode and padding, with defaults applied.
    #[test]
    fn test_parse_cipher_config() {
        let key = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";

        let mode_1 = "aes-ecb/pad:none";
        let config = CipherConfig::parse_cipher_config(key, mode_1).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Ecb);
        assert_eq!(config.padding, Padding::None);

        let mode_2 = "aes-cbc/pad:pkcs";
        let config = CipherConfig::parse_cipher_config(key, mode_2).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        let mode_3 = "aes";
        let config = CipherConfig::parse_cipher_config(key, mode_3).unwrap();
        assert_eq!(config.algorithm, Algorithm::Aes);
        assert_eq!(config.mode, Mode::Cbc);
        assert_eq!(config.padding, Padding::Pkcs);

        // A mode without an algorithm does not match the pattern.
        let mode_4 = "cbc";
        assert!(CipherConfig::parse_cipher_config(key, mode_4).is_err());

        // Neither does an unknown padding name.
        let mode_5 = "aes-cbc/pad:zero";
        assert!(CipherConfig::parse_cipher_config(key, mode_5).is_err());
    }

    /// Only 16-, 24- and 32-byte keys are accepted.
    #[test]
    fn test_invalid_key_length() {
        assert!(CipherConfig::parse_cipher_config(b"too short", "aes").is_err());
        assert!(CipherConfig::parse_cipher_config(b"", "aes-ecb").is_err());
    }
}
1 change: 1 addition & 0 deletions src/expr/impl/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ mod to_char;
mod to_jsonb;
mod vnode;
pub use to_jsonb::*;
mod encrypt;
mod external;
mod to_timestamp;
mod translate;
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,8 @@ impl Binder {
("sha256", raw_call(ExprType::Sha256)),
("sha384", raw_call(ExprType::Sha384)),
("sha512", raw_call(ExprType::Sha512)),
("encrypt", raw_call(ExprType::Encrypt)),
("decrypt", raw_call(ExprType::Decrypt)),
("left", raw_call(ExprType::Left)),
("right", raw_call(ExprType::Right)),
("int8send", raw_call(ExprType::PgwireSend)),
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ impl ExprVisitor for ImpureAnalyzer {
| expr_node::Type::Sha256
| expr_node::Type::Sha384
| expr_node::Type::Sha512
| expr_node::Type::Decrypt
| expr_node::Type::Encrypt
| expr_node::Type::Tand
| expr_node::Type::ArrayPositions
| expr_node::Type::StringToArray
Expand Down
Loading

0 comments on commit 344cf99

Please sign in to comment.