Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): introduce #[derive(StructType)] for struct types in Rust UDF #15372

Merged
merged 9 commits into from
Mar 14, 2024
335 changes: 131 additions & 204 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = "0.1"
arrow-udf-wasm = { version = "0.1.2", features = ["build"] }
arrow-udf-wasm = { version = "0.2", features = ["build"] }
arrow-udf-python = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "6c32f71" }
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
Expand Down
1 change: 1 addition & 0 deletions ci/scripts/run-e2e-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pkill java

echo "--- e2e, $mode, embedded udf"
sqllogictest -p 4566 -d dev './e2e_test/udf/wasm_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/rust_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/js_udf.slt'
sqllogictest -p 4566 -d dev './e2e_test/udf/python_udf.slt'

Expand Down
129 changes: 129 additions & 0 deletions e2e_test/udf/rust_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
statement ok
create function int_42() returns int language rust as $$
#[function("int_42() -> int")]
// Zero-argument scalar UDF used as the simplest smoke test: always yields 42.
fn int_42() -> i32 {
42
}
$$;

query I
select int_42();
----
42

statement ok
drop function int_42;


statement ok
create function gcd(int, int) returns int language rust as $$
#[function("gcd(int, int) -> int")]
fn gcd(mut a: i32, mut b: i32) -> i32 {
    // Euclid's algorithm: keep replacing (a, b) with (b, a % b) until the
    // remainder is zero; whatever is left in `a` is the result.
    loop {
        if b == 0 {
            return a;
        }
        let remainder = a % b;
        a = b;
        b = remainder;
    }
}
$$;

query I
select gcd(25, 15);
----
5

statement ok
drop function gcd;


# decimal_add: inline Rust UDF that adds two decimals.
# Every record needs a directive line; without `statement ok` the
# `create function` text is not a valid sqllogictest record.
statement ok
create function decimal_add(a decimal, b decimal) returns decimal language rust as $$
#[function("decimal_add(decimal, decimal) -> decimal")]
fn decimal_add(a: Decimal, b: Decimal) -> Decimal {
    a + b
}
$$;

query R
select decimal_add(1.11, 2.22);
----
3.33

statement ok
drop function decimal_add;

statement ok
create function datetime(d date, t time) returns timestamp language rust as $$
#[function("datetime(date, time) -> timestamp")]
// Joins a calendar date and a time of day into a single timestamp value.
fn datetime(date: NaiveDate, time: NaiveTime) -> NaiveDateTime {
NaiveDateTime::new(date, time)
}
$$;

query T
select datetime('2020-01-01', '12:34:56');
----
2020-01-01 12:34:56

statement ok
drop function datetime;


# jsonb_access: inline Rust UDF that looks up element `index` of a JSON value
# and yields no value when the lookup finds nothing.
# Every record needs a directive line; without `statement ok` the
# `create function` text is not a valid sqllogictest record.
statement ok
create function jsonb_access(json jsonb, index int) returns jsonb language rust as $$
#[function("jsonb_access(json, int) -> json")]
fn jsonb_access(json: serde_json::Value, index: i32) -> Option<serde_json::Value> {
    // A negative `index` wraps to a very large `usize` under `as`, so the
    // lookup simply misses and the function returns `None`.
    json.get(index as usize).cloned()
}
$$;

query T
select jsonb_access(a::jsonb, 1) from
(values ('["a", "b", "c"]'), (null), ('[0, false]')) t(a);
----
"b"
NULL
false

statement ok
drop function jsonb_access;


statement ok
create function key_value(varchar) returns struct<key varchar, value varchar> language rust as $$
// Pair of borrowed string slices sharing one lifetime; this is the struct
// type named in the `key_value` function signature below.
#[derive(StructType)]
struct KeyValue<'a> {
key: &'a str,
value: &'a str,
}

#[function("key_value(varchar) -> struct KeyValue")]
fn key_value(kv: &str) -> Option<KeyValue<'_>> {
    // Split on the first '='; input without a separator produces `None`.
    kv.split_once('=')
        .map(|(key, value)| KeyValue { key, value })
}
$$;

query T
select key_value('a=1');
----
(a,1)

statement ok
drop function key_value;


statement ok
create function series(n int) returns table (x int) language rust as $$
#[function("series(int) -> setof int")]
fn series(n: i32) -> impl Iterator<Item = i32> {
    // A range is already an iterator, so no `.into_iter()` conversion is
    // needed. A zero or negative `n` produces an empty series.
    0..n
}
$$;

query I
select series(3);
----
0
1
2

statement ok
drop function series;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = "0.1"
arrow-udf = "0.2"
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
25 changes: 17 additions & 8 deletions e2e_test/udf/wasm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow_udf::function;
use arrow_udf::types::StructType;
use rust_decimal::Decimal;

#[function("count_char(varchar, varchar) -> int")]
Expand Down Expand Up @@ -55,18 +56,26 @@ fn length(s: impl AsRef<[u8]>) -> i32 {
s.as_ref().len() as i32
}

#[function("extract_tcp_info(bytea) -> struct<src_addr:varchar,dst_addr:varchar,src_port:smallint,dst_port:smallint>")]
fn extract_tcp_info(tcp_packet: &[u8]) -> (String, String, i16, i16) {
#[derive(StructType)]
struct TcpInfo {
src_addr: String,
dst_addr: String,
src_port: i16,
dst_port: i16,
}

#[function("extract_tcp_info(bytea) -> struct TcpInfo")]
fn extract_tcp_info(tcp_packet: &[u8]) -> TcpInfo {
let src_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[12..16]).unwrap());
let dst_addr = std::net::Ipv4Addr::from(<[u8; 4]>::try_from(&tcp_packet[16..20]).unwrap());
let src_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[20..22]).unwrap());
let dst_port = u16::from_be_bytes(<[u8; 2]>::try_from(&tcp_packet[22..24]).unwrap());
(
src_addr.to_string(),
dst_addr.to_string(),
src_port as i16,
dst_port as i16,
)
TcpInfo {
src_addr: src_addr.to_string(),
dst_addr: dst_addr.to_string(),
src_port: src_port as i16,
dst_port: dst_port as i16,
}
}

#[function("decimal_add(decimal, decimal) -> decimal")]
Expand Down
21 changes: 0 additions & 21 deletions e2e_test/udf/wasm_udf.slt
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,3 @@ drop function jsonb_access;

statement ok
drop function series;

# inlined rust function
statement ok
create function gcd(int, int) returns int language rust as $$
fn gcd(mut a: i32, mut b: i32) -> i32 {
while b != 0 {
let t = b;
b = a % b;
a = t;
}
a
}
$$;

query I
select gcd(25, 15);
----
5

statement ok
drop function gcd;
85 changes: 47 additions & 38 deletions src/frontend/src/handler/create_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,43 +187,43 @@ pub async fn handle_create_function(
);
}
"rust" => {
identifier = wasm_identifier(
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
if params.using.is_some() {
return Err(ErrorCode::InvalidParameterValue(
"USING is not supported for rust function".to_string(),
)
.into());
}
let function_body = params
let script = params
.as_
.ok_or_else(|| ErrorCode::InvalidParameterValue("AS must be specified".into()))?
.into_string();
let script = format!("#[arrow_udf::function(\"{identifier}\")]\n{function_body}");
body = Some(function_body.clone());
let script = format!("use arrow_udf::{{function, types::*}};\n{}", script);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the introduction of the new struct syntax, users now need to write #[function(...)] explicitly (previously we would prepend this macro for them).

What about checking whether the user already writes #[function in the script, and deciding based on that whether to add it for them? 🤡 That way, we can keep the simple old syntax.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds a little hacky. But I like this idea. 🤡

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Final rule:

  • if the return type is a struct, users have to write #[function] themselves.
  • otherwise, we add it for them.

body = Some(script.clone());

let wasm_binary = tokio::task::spawn_blocking(move || {
let mut opts = arrow_udf_wasm::build::BuildOpts::default();
opts.script = script;
// use a fixed tempdir to reuse the build cache
opts.tempdir = Some(std::env::temp_dir().join("risingwave-rust-udf"));

let wasm_binary =
tokio::task::spawn_blocking(move || arrow_udf_wasm::build::build("", &script))
.await?
.context("failed to build rust function")?;
arrow_udf_wasm::build::build_with(&opts)
})
.await?
.context("failed to build rust function")?;

// below is the same as `wasm` language
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
check_wasm_function(&runtime, &identifier)?;

compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_slice(), 0)?);
}
"wasm" => {
identifier = wasm_identifier(
identifier = get_wasm_identifier(
&runtime,
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
);
)?;

compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_slice(), 0)?);
}
"wasm" => {
let Some(using) = params.using else {
return Err(ErrorCode::InvalidParameterValue(
"USING must be specified".to_string(),
Expand All @@ -242,7 +242,13 @@ pub async fn handle_create_function(
}
};
let runtime = get_or_create_wasm_runtime(&wasm_binary)?;
check_wasm_function(&runtime, &identifier)?;
identifier = get_wasm_identifier(
&runtime,
&function_name,
&arg_types,
&return_type,
matches!(kind, Kind::Table(_)),
)?;

compressed_binary = Some(zstd::stream::encode_all(wasm_binary.as_ref(), 0)?);
}
Expand Down Expand Up @@ -288,28 +294,31 @@ async fn download_binary_from_link(link: &str) -> Result<Bytes> {
}
}

/// Ensure the wasm binary exports a function named `identifier`.
///
/// On a miss, the returned `InvalidParameterValue` error lists every function
/// name reported by the runtime as a hint.
fn check_wasm_function(runtime: &arrow_udf_wasm::Runtime, identifier: &str) -> Result<()> {
    // Happy path first: nothing else to do when the export exists.
    if runtime.functions().contains(&identifier) {
        return Ok(());
    }
    let available = runtime.functions().join("\n ");
    let message = format!(
        "function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}",
        identifier, available
    );
    Err(ErrorCode::InvalidParameterValue(message).into())
}

/// Generate the function identifier in wasm binary.
fn wasm_identifier(name: &str, args: &[DataType], ret: &DataType, table_function: bool) -> String {
format!(
/// Get the function identifier in wasm binary.
fn get_wasm_identifier(
runtime: &arrow_udf_wasm::Runtime,
name: &str,
args: &[DataType],
ret: &DataType,
table_function: bool,
) -> Result<String> {
let inlined_signature = format!(
"{}({}){}{}",
name,
args.iter().map(datatype_name).join(","),
if table_function { "->>" } else { "->" },
datatype_name(ret)
)
);
let identifier = runtime
.find_function_by_inlined_signature(&inlined_signature)
.ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"function not found in wasm binary: \"{}\"\nHINT: available functions:\n {}",
inlined_signature,
runtime.functions().join("\n ")
))
})?;
Ok(identifier.into())
}

/// Convert a data type to string used in identifier.
Expand Down
2 changes: 0 additions & 2 deletions src/workspace-hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ reqwest = { version = "0.11", features = ["blocking", "json", "rustls-tls"] }
ring = { version = "0.16", features = ["std"] }
rust_decimal = { version = "1", features = ["db-postgres", "maths"] }
rustc-hash = { version = "1" }
rustix = { version = "0.38", features = ["fs", "net"] }
scopeguard = { version = "1" }
sea-orm = { version = "0.12", features = ["runtime-tokio-native-tls", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite"] }
sea-query = { version = "0.30", default-features = false, features = ["backend-mysql", "backend-postgres", "backend-sqlite", "derive", "hashable-value", "postgres-array", "thread-safe", "with-bigdecimal", "with-chrono", "with-json", "with-rust_decimal", "with-time", "with-uuid"] }
Expand Down Expand Up @@ -197,7 +196,6 @@ regex = { version = "1" }
regex-automata = { version = "0.4", default-features = false, features = ["dfa", "hybrid", "meta", "nfa", "perf", "unicode"] }
regex-syntax = { version = "0.8" }
rustc-hash = { version = "1" }
rustix = { version = "0.38", features = ["fs", "net"] }
serde = { version = "1", features = ["alloc", "derive", "rc"] }
serde_json = { version = "1", features = ["alloc", "raw_value"] }
sha2 = { version = "0.10", features = ["oid"] }
Expand Down
Loading