From adb1d83bdaa169d449a26ba3f5df824195e66320 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 29 Dec 2023 19:51:55 +0800 Subject: [PATCH] support json and decimal Signed-off-by: Runji Wang --- Cargo.lock | 2 +- Cargo.toml | 2 +- e2e_test/udf/wasm/Cargo.toml | 4 ++- e2e_test/udf/wasm/src/lib.rs | 11 +++++++++ e2e_test/udf/wasm_udf.slt | 27 +++++++++++++++++++++ src/expr/udf/README.md | 2 +- src/frontend/src/handler/create_function.rs | 2 +- 7 files changed, 45 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3d49587e11d2..a9dfdfd17840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -634,7 +634,7 @@ dependencies = [ [[package]] name = "arrow-udf-wasm" version = "0.1.0" -source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=34e1c9d#34e1c9d56bb8e1d11bc03faf20d0ae8cce6883d3" +source = "git+https://github.com/risingwavelabs/arrow-udf.git?rev=3ac8371#3ac8371d799d955cc9e383cbcfb000027df9d2cf" dependencies = [ "anyhow", "arrow-array 49.0.0", diff --git a/Cargo.toml b/Cargo.toml index c5d4e4290ff4..a9c69b70d9a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,7 +131,7 @@ arrow-flight = "49" arrow-select = "49" arrow-ord = "49" arrow-row = "49" -arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "34e1c9d" } +arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "3ac8371" } arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" } arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" } arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" } diff --git a/e2e_test/udf/wasm/Cargo.toml b/e2e_test/udf/wasm/Cargo.toml index 405b133e088e..39f3d6e92130 100644 --- a/e2e_test/udf/wasm/Cargo.toml +++ b/e2e_test/udf/wasm/Cargo.toml @@ -8,4 +8,6 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] -arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf-wasm.git", rev = "34e1c9d" } +arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "3ac8371" } +rust_decimal = "1" +serde_json = "1" diff --git a/e2e_test/udf/wasm/src/lib.rs b/e2e_test/udf/wasm/src/lib.rs index fb1b51147627..a7149dea8cfb 100644 --- a/e2e_test/udf/wasm/src/lib.rs +++ b/e2e_test/udf/wasm/src/lib.rs @@ -1,4 +1,5 @@ use arrow_udf::function; +use rust_decimal::Decimal; #[function("int_42() -> int")] fn int_42() -> i32 { @@ -62,3 +63,13 @@ fn extract_tcp_info(tcp_packet: &[u8]) -> (String, String, i16, i16) { dst_port as i16, ) } + +#[function("decimal_add(decimal, decimal) -> decimal")] +fn decimal_add(a: Decimal, b: Decimal) -> Decimal { + a + b +} + +#[function("jsonb_access(json, int) -> json")] +fn jsonb_access(json: serde_json::Value, index: i32) -> Option { + json.get(index as usize).cloned() +} diff --git a/e2e_test/udf/wasm_udf.slt b/e2e_test/udf/wasm_udf.slt index efe78e0ba9a7..54896ca107ae 100644 --- a/e2e_test/udf/wasm_udf.slt +++ b/e2e_test/udf/wasm_udf.slt @@ -17,6 +17,14 @@ statement ok create function extract_tcp_info(bytea) returns struct language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm'; +statement ok +create function decimal_add(decimal, decimal) returns decimal +language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm'; + +statement ok +create function jsonb_access(jsonb, int) returns jsonb +language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm'; + query I select int_42(); ---- @@ -37,6 +45,19 @@ select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d97 ---- (192.168.0.14,192.168.0.1,861,8374) +query R +select decimal_add(1.11, 2.22); +---- +3.33 + +query T +select jsonb_access(a::jsonb, 1) from +(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a); +---- +"b" +NULL +false + statement ok drop function int_42; @@ -48,3 +69,9 @@ drop function gcd(int,int,int); statement ok drop function extract_tcp_info; + +statement ok +drop function decimal_add; + +statement ok +drop function jsonb_access; diff --git a/src/expr/udf/README.md b/src/expr/udf/README.md index 97df6d622f03..cc9b665b2901 100644 --- a/src/expr/udf/README.md +++ b/src/expr/udf/README.md @@ -50,7 +50,7 @@ fn gcd(mut x: i32, mut y: i32) -> i32 { You can find more usages in the [documentation](https://docs.rs/arrow_udf/0.1.0/arrow_udf/attr.function.html) and more examples in the [tests](https://github.com/risingwavelabs/arrow-udf/blob/main/arrow-udf/tests/tests.rs). Currently we only support scalar functions with a limited set of data types. -`decimal`, `timestamptz`, `jsonb` and complex array types are not supported yet. +`timestamptz` and complex array types are not supported yet. ## 3. Build the project diff --git a/src/frontend/src/handler/create_function.rs b/src/frontend/src/handler/create_function.rs index 3c8300d3e979..8a662e8f80e2 100644 --- a/src/frontend/src/handler/create_function.rs +++ b/src/frontend/src/handler/create_function.rs @@ -284,7 +284,7 @@ fn datatype_name(ty: &DataType) -> String { DataType::Timestamptz => "timestamptz".to_string(), DataType::Interval => "interval".to_string(), DataType::Decimal => "decimal".to_string(), - DataType::Jsonb => "jsonb".to_string(), + DataType::Jsonb => "json".to_string(), DataType::Serial => "serial".to_string(), DataType::Int256 => "int256".to_string(), DataType::Bytea => "bytea".to_string(),