Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): add initial support for WASM-based Rust UDF (#14271) #14714

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,164 changes: 1,144 additions & 20 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ members = [
"src/utils/workspace-config",
"src/workspace-hack",
]
exclude = ["lints"]
exclude = ["e2e_test/udf/wasm", "lints"]
resolver = "2"

[workspace.package]
Expand Down Expand Up @@ -131,6 +131,7 @@ arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand Down
8 changes: 8 additions & 0 deletions ci/scripts/build-other.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ set -euo pipefail
source ci/scripts/common.sh


echo "--- Build Rust UDF"
cd e2e_test/udf/wasm
rustup target add wasm32-wasi
cargo build --release
cd ../../..

echo "--- Build Java packages"
cd java
mvn -B package -Dmaven.test.skip=true
Expand All @@ -26,6 +32,8 @@ tar --zstd -cf java-binding-integration-test.tar.zst bin java/java-binding-integ
echo "--- Upload Java artifacts"
cp java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz ./risingwave-connector.tar.gz
cp java/udf-example/target/risingwave-udf-example.jar ./risingwave-udf-example.jar
cp e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm udf.wasm
buildkite-agent artifact upload ./risingwave-connector.tar.gz
buildkite-agent artifact upload ./risingwave-udf-example.jar
buildkite-agent artifact upload ./java-binding-integration-test.tar.zst
buildkite-agent artifact upload ./udf.wasm
5 changes: 5 additions & 0 deletions ci/scripts/run-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ download_and_prepare_rw "$profile" common
echo "--- Download artifacts"
download-and-decompress-artifact e2e_test_generated ./
download-and-decompress-artifact risingwave_e2e_extended_mode_test-"$profile" target/debug/
mkdir -p e2e_test/udf/wasm/target/wasm32-wasi/release/
buildkite-agent artifact download udf.wasm e2e_test/udf/wasm/target/wasm32-wasi/release/
buildkite-agent artifact download risingwave-udf-example.jar ./
mv target/debug/risingwave_e2e_extended_mode_test-"$profile" target/debug/risingwave_e2e_extended_mode_test

Expand Down Expand Up @@ -97,6 +99,9 @@ sleep 1
sqllogictest -p 4566 -d dev './e2e_test/udf/udf.slt'
pkill java

echo "--- e2e, $mode, wasm udf"
sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt'

echo "--- Kill cluster"
cluster_stop

Expand Down
2 changes: 2 additions & 0 deletions e2e_test/udf/wasm/.cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[build]
target = "wasm32-wasi"
2 changes: 2 additions & 0 deletions e2e_test/udf/wasm/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cargo.lock
target
14 changes: 14 additions & 0 deletions e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "udf"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[lib]
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
80 changes: 80 additions & 0 deletions e2e_test/udf/wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use arrow_udf::function;
use rust_decimal::Decimal;

#[function("int_42() -> int")]
fn int_42() -> i32 {
42
}

#[function("gcd(int, int) -> int")]
fn gcd(mut a: i32, mut b: i32) -> i32 {
while b != 0 {
let t = b;
b = a % b;
a = t;
}
a
}

#[function("gcd(int, int, int) -> int")]
fn gcd3(a: i32, b: i32, c: i32) -> i32 {
gcd(gcd(a, b), c)
}

#[function("sleep(int) -> int")]
fn sleep(second: i32) -> i32 {
std::thread::sleep(std::time::Duration::from_secs(second as u64));
0
}

#[function("segfault() -> int")]
fn segfault() -> i32 {
unsafe { (usize::MAX as *const i32).read_volatile() }
}

#[function("oom() -> int")]
fn oom() -> i32 {
_ = vec![0u8; usize::MAX];
0
}

#[function("create_file() -> int")]
fn create_file() -> i32 {
std::fs::File::create("test").unwrap();
0
}

#[function("length(varchar) -> int")]
#[function("length(bytea) -> int")]
fn length(s: impl AsRef<[u8]>) -> i32 {
s.as_ref().len() as i32
}

#[function("extract_tcp_info(bytea) -> struct<src_addr:varchar,dst_addr:varchar,src_port:smallint,dst_port:smallint>")]
fn extract_tcp_info(tcp_packet: &[u8]) -> (String, String, i16, i16) {
let src_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[12..16]).unwrap());
let dst_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[16..20]).unwrap());
let src_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[20..22]).unwrap());
let dst_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[22..24]).unwrap());
(
src_addr.to_string(),
dst_addr.to_string(),
src_port as i16,
dst_port as i16,
)
}

#[function("decimal_add(decimal, decimal) -> decimal")]
fn decimal_add(a: Decimal, b: Decimal) -> Decimal {
a + b
}

#[function("jsonb_access(json, int) -> json")]
fn jsonb_access(json: serde_json::Value, index: i32) -> Option<serde_json::Value> {
json.get(index as usize).cloned()
}

#[function("series(int) -> setof int")]
fn series(n: i32) -> impl Iterator<Item = i32> {
0..n
}
93 changes: 93 additions & 0 deletions e2e_test/udf/wasm_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Before running this test:
# cd e2e_test/udf/wasm && cargo build --release

statement ok
create function int_42() returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function gcd(int, int) returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function gcd(int, int, int) returns int
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function extract_tcp_info(bytea) returns struct<src_addr varchar, dst_addr varchar, src_port smallint, dst_port smallint>
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function decimal_add(decimal, decimal) returns decimal
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function jsonb_access(jsonb, int) returns jsonb
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

statement ok
create function series(int) returns table (x int)
language wasm using link 'fs://e2e_test/udf/wasm/target/wasm32-wasi/release/udf.wasm';

query I
select int_42();
----
42

query I
select gcd(25, 15);
----
5

query I
select gcd(25, 15, 3);
----
1

query T
select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea);
----
(192.168.0.14,192.168.0.1,861,8374)

query R
select decimal_add(1.11, 2.22);
----
3.33

query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
----
"b"
NULL
false

query I
select series(5);
----
0
1
2
3
4

statement ok
drop function int_42;

statement ok
drop function gcd(int,int);

statement ok
drop function gcd(int,int,int);

statement ok
drop function extract_tcp_info;

statement ok
drop function decimal_add;

statement ok
drop function jsonb_access;

statement ok
drop function series;
10 changes: 10 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ message ExprNode {
PG_GET_INDEXDEF = 2400;
COL_DESCRIPTION = 2401;
}
// Only use this field for function call. For other types of expression, it should be UNSPECIFIED.
Type function_type = 1;
data.DataType return_type = 3;
oneof rex_node {
Expand Down Expand Up @@ -460,15 +461,24 @@ message WindowFunction {
WindowFrame frame = 5;
}

// Note: due to historic reasons, UserDefinedFunction is a oneof variant parallel to FunctionCall,
// while UserDefinedTableFunction is embedded as a field in TableFunction.

message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
// An unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated data.DataType arg_types = 3;
string language = 4;
Expand Down
1 change: 1 addition & 0 deletions proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ message SystemParams {
optional uint32 parallel_compact_size_mb = 11;
optional uint32 max_concurrent_creating_streaming_jobs = 12;
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14;
}

message GetSystemParamsRequest {}
Expand Down
4 changes: 4 additions & 0 deletions src/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,9 @@ pub struct SystemConfig {
/// Whether to pause all data sources on next bootstrap.
#[serde(default = "default::system::pause_on_next_bootstrap")]
pub pause_on_next_bootstrap: Option<bool>,

#[serde(default = "default::system::wasm_storage_url")]
pub wasm_storage_url: Option<String>,
}

/// The subsections `[storage.object_store]`.
Expand Down Expand Up @@ -927,6 +930,7 @@ impl SystemConfig {
max_concurrent_creating_streaming_jobs: self.max_concurrent_creating_streaming_jobs,
pause_on_next_bootstrap: self.pause_on_next_bootstrap,
telemetry_enabled: None, // deprecated
wasm_storage_url: self.wasm_storage_url,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ macro_rules! for_all_params {
{ backup_storage_directory, String, Some("backup".to_string()), true },
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true },
{ pause_on_next_bootstrap, bool, Some(false), true },
{ wasm_storage_url, String, Some("fs://.risingwave/data".to_string()), false },
}
};
}
Expand Down Expand Up @@ -357,6 +358,7 @@ mod tests {
(BACKUP_STORAGE_DIRECTORY_KEY, "a"),
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(WASM_STORAGE_URL_KEY, "a"),
("a_deprecated_param", "foo"),
];

Expand Down
4 changes: 4 additions & 0 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl SystemParamsReader {
self.prost.pause_on_next_bootstrap.unwrap_or(false)
}

pub fn wasm_storage_url(&self) -> &str {
self.prost.wasm_storage_url.as_ref().unwrap()
}

pub fn to_kv(&self) -> Vec<(String, String)> {
system_params_to_kv(&self.prost).unwrap()
}
Expand Down
1 change: 1 addition & 0 deletions src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,4 @@ backup_storage_url = "memory"
backup_storage_directory = "backup"
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
wasm_storage_url = "fs://.risingwave/data"
3 changes: 3 additions & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
await-tree = { workspace = true }
Expand All @@ -35,11 +36,13 @@ enum-as-inner = "0.6"
futures-async-stream = { workspace = true }
futures-util = "0.3"
itertools = "0.12"
moka = { version = "0.12", features = ["future"] }
num-traits = "0.2"
parse-display = "0.8"
paste = "1"
risingwave_common = { workspace = true }
risingwave_expr_macro = { path = "../macro" }
risingwave_object_store = { workspace = true }
risingwave_pb = { workspace = true }
risingwave_udf = { workspace = true }
smallvec = "1"
Expand Down
4 changes: 2 additions & 2 deletions src/expr/core/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use risingwave_pb::expr::expr_node::{PbType, RexNode};
use risingwave_pb::expr::ExprNode;

use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_udf::UserDefinedFunction;
use super::strict::Strict;
use super::wrapper::checked::Checked;
use super::wrapper::non_strict::NonStrict;
Expand Down Expand Up @@ -104,7 +104,7 @@ where
match prost.get_rex_node()? {
RexNode::InputRef(_) => InputRefExpression::build_boxed(prost, build_child),
RexNode::Constant(_) => LiteralExpression::build_boxed(prost, build_child),
RexNode::Udf(_) => UdfExpression::build_boxed(prost, build_child),
RexNode::Udf(_) => UserDefinedFunction::build_boxed(prost, build_child),

RexNode::FuncCall(_) => match prost.function_type() {
// Dedicated types
Expand Down
Loading