Skip to content

Commit

Permalink
feat(udf): add initial support for WASM-based Rust UDF (#14271)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: xxchan <[email protected]>
Co-authored-by: wangrunji0408 <[email protected]>
  • Loading branch information
3 people authored Jan 12, 2024
1 parent 89a8297 commit e8f1eb9
Show file tree
Hide file tree
Showing 34 changed files with 1,864 additions and 131 deletions.
1,162 changes: 1,143 additions & 19 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ members = [
"src/utils/workspace-config",
"src/workspace-hack",
]
exclude = ["lints"]
exclude = ["e2e_test/udf/wasm", "lints"]
resolver = "2"

[workspace.package]
Expand Down Expand Up @@ -132,6 +132,7 @@ arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand Down
8 changes: 8 additions & 0 deletions ci/scripts/build-other.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ set -euo pipefail
source ci/scripts/common.sh


echo "--- Build Rust UDF"
cd e2e_test/udf/wasm
rustup target add wasm32-wasi
cargo build --release
cd ../../..

echo "--- Build Java packages"
cd java
mvn -B package -Dmaven.test.skip=true
Expand All @@ -26,6 +32,8 @@ tar --zstd -cf java-binding-integration-test.tar.zst bin java/java-binding-integ
echo "--- Upload Java artifacts"
cp java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz
cp java/udf-example/target/risingwave-udf-example.jar ./risingwave-udf-example.jar
cp e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm udf.wasm
buildkite-agent artifact upload ./risingwave-connector.tar.gz
buildkite-agent artifact upload ./risingwave-udf-example.jar
buildkite-agent artifact upload ./java-binding-integration-test.tar.zst
buildkite-agent artifact upload ./udf.wasm
5 changes: 5 additions & 0 deletions ci/scripts/run-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ download_and_prepare_rw "$profile" common
echo "--- Download artifacts"
download-and-decompress-artifact e2e_test_generated ./
download-and-decompress-artifact risingwave_e2e_extended_mode_test-"$profile" target/debug/
mkdir -p e2e_test/udf/wasm/target/wasm32-wasi/release/
buildkite-agent artifact download udf.wasm e2e_test/udf/wasm/target/wasm32-wasi/release/
buildkite-agent artifact download risingwave-udf-example.jar ./
mv target/debug/risingwave_e2e_extended_mode_test-"$profile" target/debug/risingwave_e2e_extended_mode_test

Expand Down Expand Up @@ -97,6 +99,9 @@ sleep 1
sqllogictest -p 4566 -d dev './e2e_test/udf/udf.slt'
pkill java

echo "--- e2e, $mode, wasm udf"
sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt'

echo "--- Kill cluster"
cluster_stop

Expand Down
2 changes: 2 additions & 0 deletions e2e_test/udf/wasm/.cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[build]
target = "wasm32-wasi"
2 changes: 2 additions & 0 deletions e2e_test/udf/wasm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cargo.lock
target
14 changes: 14 additions & 0 deletions e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "udf"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
80 changes: 80 additions & 0 deletions e2e_test/udf/wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use arrow_udf::function;
use rust_decimal::Decimal;

#[function("int_42() -> int")]
/// Returns the constant `42`; takes no arguments.
fn int_42() -> i32 {
    const ANSWER: i32 = 42;
    ANSWER
}

#[function("gcd(int, int) -> int")]
/// Computes the greatest common divisor of `a` and `b` with the Euclidean algorithm.
///
/// The sign of the result follows the operands as they are reduced (it is not
/// normalized to be non-negative), and `gcd(0, 0)` is `0`.
fn gcd(mut a: i32, mut b: i32) -> i32 {
    while b != 0 {
        // `wrapping_rem` rather than `%`: `i32::MIN % -1` overflows and panics,
        // while the mathematically correct remainder in that case is simply 0.
        // For every other operand pair the two operations agree.
        let remainder = a.wrapping_rem(b);
        a = b;
        b = remainder;
    }
    a
}

#[function("gcd(int, int, int) -> int")]
/// Computes the greatest common divisor of three integers by folding the
/// two-argument `gcd` over them: first `a` with `b`, then that result with `c`.
fn gcd3(a: i32, b: i32, c: i32) -> i32 {
    let gcd_of_first_two = gcd(a, b);
    gcd(gcd_of_first_two, c)
}

#[function("sleep(int) -> int")]
/// Blocks the calling thread for `second` seconds and then returns `0`.
///
/// Zero or negative values return without waiting.
fn sleep(second: i32) -> i32 {
    // A bare `second as u64` cast would wrap a negative value into an enormous
    // duration (e.g. `-1` becomes `u64::MAX` seconds), so clamp at zero first.
    let seconds = second.max(0) as u64;
    std::thread::sleep(std::time::Duration::from_secs(seconds));
    0
}

#[function("segfault() -> int")]
/// Deliberately reads an `i32` through the invalid address `usize::MAX`.
///
/// NOTE(review): judging by the name, this is a fault-injection fixture meant to
/// make the module crash — confirm before "fixing" it.
fn segfault() -> i32 {
// SAFETY: intentionally unsound — `usize::MAX` is not a valid, aligned `i32` pointer.
unsafe { (usize::MAX as *const i32).read_volatile() }
}

#[function("oom() -> int")]
/// Requests a zero-filled vector of `usize::MAX` bytes, discards it, and returns `0`.
///
/// NOTE(review): presumably a fixture for the allocation-failure path (per the
/// name); the request cannot be satisfied, so the trailing `0` is likely never
/// reached — confirm.
fn oom() -> i32 {
_ = vec![0u8; usize::MAX];
0
}

#[function("create_file() -> int")]
/// Creates (or truncates) a file named `test` in the current directory and returns `0`.
///
/// # Panics
///
/// Panics if the file cannot be created.
fn create_file() -> i32 {
    let created = std::fs::File::create("test").unwrap();
    // Close the handle right away; only the act of creating the file matters here.
    drop(created);
    0
}

#[function("length(varchar) -> int")]
#[function("length(bytea) -> int")]
/// Returns the length of `s` in bytes (not characters), cast to `i32`.
fn length(s: impl AsRef<[u8]>) -> i32 {
    let bytes: &[u8] = s.as_ref();
    bytes.len() as i32
}

#[function("extract_tcp_info(bytea) -> struct<src_addr:varchar,dst_addr:varchar,src_port:smallint,dst_port:smallint>")]
/// Extracts `(src_addr, dst_addr, src_port, dst_port)` from a raw packet.
///
/// The addresses are read from bytes 12..16 and 16..20 and rendered in dotted
/// IPv4 notation; the ports are read as big-endian 16-bit values from bytes
/// 20..22 and 22..24 and reinterpreted as `i16`, so values above `i16::MAX`
/// come out negative.
///
/// # Panics
///
/// Panics if `tcp_packet` is shorter than 24 bytes.
fn extract_tcp_info(tcp_packet: &[u8]) -> (String, String, i16, i16) {
    // Reads four bytes starting at `offset` as an IPv4 address.
    let ipv4_at = |offset: usize| {
        let mut octets = [0u8; 4];
        octets.copy_from_slice(&tcp_packet[offset..offset + 4]);
        std::net::Ipv4Addr::from(octets)
    };
    // Reads two bytes starting at `offset` as a big-endian port number.
    let port_at = |offset: usize| {
        let mut raw = [0u8; 2];
        raw.copy_from_slice(&tcp_packet[offset..offset + 2]);
        u16::from_be_bytes(raw)
    };

    let source_ip = ipv4_at(12);
    let destination_ip = ipv4_at(16);
    let source_port = port_at(20);
    let destination_port = port_at(22);

    (
        source_ip.to_string(),
        destination_ip.to_string(),
        source_port as i16,
        destination_port as i16,
    )
}

#[function("decimal_add(decimal, decimal) -> decimal")]
/// Returns the sum of the two decimals `a` and `b`.
fn decimal_add(a: Decimal, b: Decimal) -> Decimal {
    let sum = a + b;
    sum
}

#[function("jsonb_access(json, int) -> json")]
/// Looks up position `index` in `json` and returns a copy of the value found there.
///
/// Returns `None` when the lookup yields nothing. `index` is cast to `usize`
/// as-is, so a negative value becomes a very large position.
fn jsonb_access(json: serde_json::Value, index: i32) -> Option<serde_json::Value> {
    let position = index as usize;
    let found = json.get(position);
    found.cloned()
}

#[function("series(int) -> setof int")]
/// Yields the integers `0, 1, …, n - 1` in ascending order.
///
/// The iterator is empty when `n` is zero or negative.
fn series(n: i32) -> impl Iterator<Item = i32> {
    std::ops::Range { start: 0, end: n }
}
93 changes: 93 additions & 0 deletions e2e_test/udf/wasm_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Before running this test, build the WASM UDF module:
#   rustup target add wasm32-wasi
#   cd e2e_test/udf/wasm && cargo build --release

statement ok
create function int_42() returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function gcd(int, int) returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function gcd(int, int, int) returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function extract_tcp_info(bytea) returns struct<src_addr varchar, dst_addr varchar, src_port smallint, dst_port smallint>
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function decimal_add(decimal, decimal) returns decimal
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function jsonb_access(jsonb, int) returns jsonb
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function series(int) returns table (x int)
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

query I
select int_42();
----
42

query I
select gcd(25, 15);
----
5

query I
select gcd(25, 15, 3);
----
1

query T
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea);
----
(192.168.0.14,192.168.0.1,861,8374)

query R
select decimal_add(1.11, 2.22);
----
3.33

query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
----
"b"
NULL
false

query I
select series(5);
----
0
1
2
3
4

statement ok
drop function int_42;

statement ok
drop function gcd(int,int);

statement ok
drop function gcd(int,int,int);

statement ok
drop function extract_tcp_info;

statement ok
drop function decimal_add;

statement ok
drop function jsonb_access;

statement ok
drop function series;
10 changes: 10 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ message ExprNode {
COL_DESCRIPTION = 2401;
PG_GET_VIEWDEF = 2402;
}
// Only use this field for function calls. For other types of expressions, it should be UNSPECIFIED.
Type function_type = 1;
data.DataType return_type = 3;
oneof rex_node {
Expand Down Expand Up @@ -461,15 +462,24 @@ message WindowFunction {
WindowFrame frame = 5;
}

// Note: for historical reasons, UserDefinedFunction is a oneof variant parallel to FunctionCall,
// while UserDefinedTableFunction is embedded as a field in TableFunction.

// A call to a user-defined function, carrying what is needed to locate and invoke it.
message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
// A unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated data.DataType arg_types = 3;
string language = 4;
Expand Down
1 change: 1 addition & 0 deletions proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ message SystemParams {
optional uint32 parallel_compact_size_mb = 11;
optional uint32 max_concurrent_creating_streaming_jobs = 12;
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14;
}

message GetSystemParamsRequest {}
Expand Down
4 changes: 4 additions & 0 deletions src/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,9 @@ pub struct SystemConfig {
/// Whether to pause all data sources on next bootstrap.
#[serde(default = "default::system::pause_on_next_bootstrap")]
pub pause_on_next_bootstrap: Option<bool>,

#[serde(default = "default::system::wasm_storage_url")]
pub wasm_storage_url: Option<String>,
}

/// The subsections `[storage.object_store]`.
Expand Down Expand Up @@ -953,6 +956,7 @@ impl SystemConfig {
max_concurrent_creating_streaming_jobs: self.max_concurrent_creating_streaming_jobs,
pause_on_next_bootstrap: self.pause_on_next_bootstrap,
telemetry_enabled: None, // deprecated
wasm_storage_url: self.wasm_storage_url,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ macro_rules! for_all_params {
{ backup_storage_directory, String, Some("backup".to_string()), true },
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true },
{ pause_on_next_bootstrap, bool, Some(false), true },
{ wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false },
}
};
}
Expand Down Expand Up @@ -357,6 +358,7 @@ mod tests {
(BACKUP_STORAGE_DIRECTORY_KEY, "a"),
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(WASM_STORAGE_URL_KEY, "a"),
("a_deprecated_param", "foo"),
];

Expand Down
4 changes: 4 additions & 0 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl SystemParamsReader {
self.prost.pause_on_next_bootstrap.unwrap_or(false)
}

/// Returns the `wasm_storage_url` system parameter as a string slice.
///
/// # Panics
///
/// Panics if the parameter is not set in the underlying protobuf message.
pub fn wasm_storage_url(&self) -> &str {
    self.prost.wasm_storage_url.as_deref().unwrap()
}

pub fn to_kv(&self) -> Vec<(String, String)> {
system_params_to_kv(&self.prost).unwrap()
}
Expand Down
1 change: 1 addition & 0 deletions src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ backup_storage_url = "memory"
backup_storage_directory = "backup"
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
wasm_storage_url = "fs://.risingwave/data"
3 changes: 3 additions & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
await-tree = { workspace = true }
Expand All @@ -35,11 +36,13 @@ enum-as-inner = "0.6"
futures-async-stream = { workspace = true }
futures-util = "0.3"
itertools = "0.12"
moka = { version = "0.12", features = ["future"] }
num-traits = "0.2"
parse-display = "0.8"
paste = "1"
risingwave_common = { workspace = true }
risingwave_expr_macro = { path = "../macro" }
risingwave_object_store = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_udf = { workspace = true }
smallvec = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_udf::UserDefinedFunction;
use super::strict::Strict;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
Expand Down Expand Up @@ -104,7 +104,7 @@ where
match prost.get_rex_node()? {
RexNode::InputRef(_) => InputRefExpression::build_boxed(prost, build_child),
RexNode::Constant(_) => LiteralExpression::build_boxed(prost, build_child),
RexNode::Udf(_) => UdfExpression::build_boxed(prost, build_child),
RexNode::Udf(_) => UserDefinedFunction::build_boxed(prost, build_child),

RexNode::FuncCall(_) => match prost.function_type() {
// Dedicated types
Expand Down
Loading

0 comments on commit e8f1eb9

Please sign in to comment.