Skip to content

Commit

Permalink
feat(udf): introduce #[derive(StructType)] for struct types in Rust…
Browse files Browse the repository at this point in the history
… UDF (#15372)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Mar 14, 2024
1 parent 2993ab6 commit 2a03233
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 272 deletions.
335 changes: 131 additions & 204 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = "0.1"
arrow-udf-wasm = { version = "0.1.2", features = ["build"] }
arrow-udf-wasm = { version = "0.2", features = ["build"] }
arrow-udf-python = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "6c32f71" }
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
Expand Down
1 change: 1 addition & 0 deletions ci/scripts/run-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ pkill java

echo "--- e2e, $mode, embedded udf"
sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/rust_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/js_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/python_udf.slt'

Expand Down
122 changes: 122 additions & 0 deletions e2e_test/udf/rust_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
statement ok
create function int_42() returns int language rust as $$
fn int_42() -> i32 {
42
}
$$;

query I
select int_42();
----
42

statement ok
drop function int_42;


statement ok
create function gcd(int, int) returns int language rust as $$
fn gcd(mut a: i32, mut b: i32) -> i32 {
while b != 0 {
(a, b) = (b, a % b);
}
a
}
$$;

query I
select gcd(25, 15);
----
5

statement ok
drop function gcd;

statement ok
create function decimal_add(a decimal, b decimal) returns decimal language rust as $$
fn decimal_add(a: Decimal, b: Decimal) -> Decimal {
a + b
}
$$;

query R
select decimal_add(1.11, 2.22);
----
3.33

statement ok
drop function decimal_add;

statement ok
create function datetime(d date, t time) returns timestamp language rust as $$
fn datetime(date: NaiveDate, time: NaiveTime) -> NaiveDateTime {
NaiveDateTime::new(date, time)
}
$$;

query T
select datetime('2020-01-01', '12:34:56');
----
2020-01-01 12:34:56

statement ok
drop function datetime;

statement ok
create function jsonb_access(json jsonb, index int) returns jsonb language rust as $$
fn jsonb_access(json: serde_json::Value, index: i32) -> Option<serde_json::Value> {
json.get(index as usize).cloned()
}
$$;

query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
----
"b"
NULL
false

statement ok
drop function jsonb_access;


statement ok
create function key_value(varchar) returns struct<key varchar, value varchar> language rust as $$
#[derive(StructType)]
struct KeyValue<'a> {
key: &'a str,
value: &'a str,
}
#[function("key_value(varchar) -> struct KeyValue")]
fn key_value(kv: &str) -> Option<KeyValue<'_>> {
let (key, value) = kv.split_once('=')?;
Some(KeyValue { key, value })
}
$$;

query T
select key_value('a=1');
----
(a,1)

statement ok
drop function key_value;


statement ok
create function series(n int) returns table (x int) language rust as $$
fn series(n: i32) -> impl Iterator<Item = i32> {
(0..n).into_iter()
}
$$;

query I
select series(3);
----
0
1
2

statement ok
drop function series;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = "0.1"
arrow-udf = "0.2"
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
25 changes: 17 additions & 8 deletions e2e_test/udf/wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow_udf::function;
use arrow_udf::types::StructType;
use rust_decimal::Decimal;

#[function("count_char(varchar, varchar) -> int")]
Expand Down Expand Up @@ -55,18 +56,26 @@ fn length(s: impl AsRef<[u8]>) -> i32 {
s.as_ref().len() as i32
}

#[function("extract_tcp_info(bytea) -> struct<src_addr:varchar,dst_addr:varchar,src_port:smallint,dst_port:smallint>")]
fn extract_tcp_info(tcp_packet: &[u8]) -> (String, String, i16, i16) {
#[derive(StructType)]
struct TcpInfo {
src_addr: String,
dst_addr: String,
src_port: i16,
dst_port: i16,
}

#[function("extract_tcp_info(bytea) -> struct TcpInfo")]
fn extract_tcp_info(tcp_packet: &[u8]) -> TcpInfo {
let src_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[12..16]).unwrap());
let dst_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[16..20]).unwrap());
let src_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[20..22]).unwrap());
let dst_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[22..24]).unwrap());
(
src_addr.to_string(),
dst_addr.to_string(),
src_port as i16,
dst_port as i16,
)
TcpInfo {
src_addr: src_addr.to_string(),
dst_addr: dst_addr.to_string(),
src_port: src_port as i16,
dst_port: dst_port as i16,
}
}

#[function("decimal_add(decimal, decimal) -> decimal")]
Expand Down
21 changes: 0 additions & 21 deletions e2e_test/udf/wasm_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,3 @@ drop function jsonb_access;

statement ok
drop function series;

# inlined rust function
statement ok
create function gcd(int, int) returns int language rust as $$
fn gcd(mut a: i32, mut b: i32) -> i32 {
while b != 0 {
let t = b;
b = a % b;
a = t;
}
a
}
$$;

query I
select gcd(25, 15);
----
5

statement ok
drop function gcd;
111 changes: 76 additions & 35 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,43 +187,52 @@ pub async fn handle_create_function(
);
}
"rust" => {
identifier = wasm_identifier(
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
if params.using.is_some() {
return Err(ErrorCode::InvalidParameterValue(
"USING is not supported for rust function".to_string(),
)
.into());
}
let function_body = params
let identifier_v1 = wasm_identifier_v1(
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
// if the function returns a struct, users need to add `#[function]` macro by themselves.
// otherwise, we add it automatically. the code should start with `fn ...`.
let function_macro = if return_type.is_struct() {
String::new()
} else {
format!("#[function(\"{}\")]", identifier_v1)
};
let script = params
.as_
.ok_or_else(|| ErrorCode::InvalidParameterValue("AS must be specified".into()))?
.into_string();
let script = format!("#[arrow_udf::function(\"{identifier}\")]\n{function_body}");
body = Some(function_body.clone());
let script = format!(
"use arrow_udf::{{function, types::*}};\n{}\n{}",
function_macro, script
);
body = Some(script.clone());

let wasm_binary =
tokio::task::spawn_blocking(move || arrow_udf_wasm::build::build("", &script))
.await?
.context("failed to build rust function")?;
let wasm_binary = tokio::task::spawn_blocking(move || {
let mut opts = arrow_udf_wasm::build::BuildOpts::default();
opts.script = script;
// use a fixed tempdir to reuse the build cache
opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf"));

arrow_udf_wasm::build::build_with(&opts)
})
.await?
.context("failed to build rust function")?;

// below is the same as `wasm` language
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
check_wasm_function(&runtime, &identifier)?;
identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?;

compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_slice(), 0)?);
}
"wasm" => {
identifier = wasm_identifier(
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
let Some(using) = params.using else {
return Err(ErrorCode::InvalidParameterValue(
"USING must be specified".to_string(),
Expand All @@ -242,7 +251,13 @@ pub async fn handle_create_function(
}
};
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
check_wasm_function(&runtime, &identifier)?;
let identifier_v1 = wasm_identifier_v1(
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
identifier = find_wasm_identifier_v2(&runtime, &identifier_v1)?;

compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_ref(), 0)?);
}
Expand Down Expand Up @@ -288,21 +303,47 @@ async fn download_binary_from_link(link: &str) -> Result<Bytes> {
}
}

/// Check if the function exists in the wasm binary.
fn check_wasm_function(runtime: &arrow_udf_wasm::Runtime, identifier: &str) -> Result<()> {
if !runtime.functions().contains(&identifier) {
return Err(ErrorCode::InvalidParameterValue(format!(
"function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}",
identifier,
runtime.functions().join("\n ")
))
.into());
}
Ok(())
/// Convert a v0.1 function identifier to v0.2 format.
///
/// In arrow-udf v0.1 format, struct type is inline in the identifier. e.g.
///
/// ```text
/// keyvalue(varchar,varchar)->struct<key:varchar,value:varchar>
/// ```
///
/// However, since arrow-udf v0.2, struct type is no longer inline.
/// The above identifier is divided into a function and a type.
///
/// ```text
/// keyvalue(varchar,varchar)->struct KeyValue
/// KeyValue=key:varchar,value:varchar
/// ```
///
/// For compatibility, we should call `find_wasm_identifier_v2` to
/// convert v0.1 identifiers to v0.2 format before looking up the function.
fn find_wasm_identifier_v2(
runtime: &arrow_udf_wasm::Runtime,
inlined_signature: &str,
) -> Result<String> {
let identifier = runtime
.find_function_by_inlined_signature(inlined_signature)
.ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}",
inlined_signature,
runtime.functions().join("\n ")
))
})?;
Ok(identifier.into())
}

/// Generate the function identifier in wasm binary.
fn wasm_identifier(name: &str, args: &[DataType], ret: &DataType, table_function: bool) -> String {
/// Generate a function identifier in v0.1 format from the function signature.
fn wasm_identifier_v1(
name: &str,
args: &[DataType],
ret: &DataType,
table_function: bool,
) -> String {
format!(
"{}({}){}{}",
name,
Expand Down
2 changes: 0 additions & 2 deletions src/workspace-hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ reqwest = { version = "0.11", features = ["blocking", "json", "rustls-tls"] }
ring = { version = "0.16", features = ["std"] }
rust_decimal = { version = "1", features = ["db-postgres", "maths"] }
rustc-hash = { version = "1" }
rustix = { version = "0.38", features = ["fs", "net"] }
scopeguard = { version = "1" }
sea-orm = { version = "0.12", features = ["runtime-tokio-native-tls", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"] }
sea-query = { version = "0.30", default-features = false, features = ["backend-mysql", "backend-postgres", "backend-sqlite", "derive", "hashable-value", "postgres-array", "thread-safe", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] }
Expand Down Expand Up @@ -202,7 +201,6 @@ regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa", "hybrid", "meta", "nfa", "perf", "unicode"] }
regex-syntax = { version = "0.8" }
rustc-hash = { version = "1" }
rustix = { version = "0.38", features = ["fs", "net"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive", "rc"] }
serde_json = { version = "1", features = ["alloc", "raw_value"] }
Expand Down

0 comments on commit 2a03233

Please sign in to comment.