Skip to content

Commit

Permalink
feat(udf): add initial support for JavaScript UDF (#14513)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
Co-authored-by: wangrunji0408 <[email protected]>
  • Loading branch information
wangrunji0408 and wangrunji0408 authored Jan 22, 2024
1 parent 3b8c942 commit 705be19
Show file tree
Hide file tree
Showing 21 changed files with 538 additions and 158 deletions.
236 changes: 139 additions & 97 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,17 @@ prost = { version = "0.12" }
icelake = { git = "https://github.com/icelake-io/icelake", rev = "32c0bbf242f5c47b1e743f10577012fe7436c770", features = [
"prometheus",
] }
arrow-array = "49"
arrow-arith = "49"
arrow-cast = "49"
arrow-schema = "49"
arrow-buffer = "49"
arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-array = "50"
arrow-arith = "50"
arrow-cast = "50"
arrow-schema = "50"
arrow-buffer = "50"
arrow-flight = "50"
arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "7ba1c22" }
arrow-udf-wasm = "0.1"
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand All @@ -143,7 +144,7 @@ arrow-schema-deltalake = { package = "arrow-schema", version = "48.0.1" }
deltalake = { git = "https://github.com/risingwavelabs/delta-rs", rev = "5c2dccd4640490202ffe98adbd13b09cef8e007b", features = [
"s3-no-concurrent-write",
] }
parquet = "49"
parquet = "50"
thiserror-ext = "0.0.11"
tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" }
tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [
Expand Down
154 changes: 154 additions & 0 deletions e2e_test/udf/js_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Simplest case: a zero-argument scalar JavaScript UDF that returns a constant.
# The text between `$$` is used as the JavaScript function body.
statement ok
create function int_42() returns int language javascript as $$
return 42;
$$;

# Calling the function yields the constant from the body.
query I
select int_42();
----
42

# Clean up so the function does not leak into later tests.
statement ok
drop function int_42;


# Two-argument scalar UDF: greatest common divisor via the Euclidean algorithm.
# The body checks for null arguments itself and returns null in that case
# (see the in-body note about `RETURNS NULL ON NULL INPUT`).
statement ok
create function gcd(a int, b int) returns int language javascript as $$
// required before we support `RETURNS NULL ON NULL INPUT`
if(a == null || b == null) {
return null;
}
while (b != 0) {
let t = b;
b = a % b;
a = t;
}
return a;
$$;

# gcd(25, 15) = 5.
query I
select gcd(25, 15);
----
5

# Clean up.
statement ok
drop function gcd;


# Decimal arguments and return value: the body adds them with the plain `+` operator.
statement ok
create function decimal_add(a decimal, b decimal) returns decimal language javascript as $$
return a + b;
$$;

# The expected result is the exact decimal sum 3.33.
query R
select decimal_add(1.11, 2.22);
----
3.33

# Clean up.
statement ok
drop function decimal_add;


# Passes one argument of each listed SQL type into JavaScript and concatenates
# their string forms; the jsonb argument is serialized with JSON.stringify.
statement ok
create function to_string(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns varchar language javascript as $$
return a.toString() + b.toString() + c.toString() + d.toString() + e.toString() + f.toString() + g.toString() + h.toString() + i.toString() + JSON.stringify(j);
$$;

# In the expected output the bytea value '\x010203' is rendered as "1,2,3"
# and the jsonb value as compact JSON without spaces.
query T
select to_string(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
----
false1234.56.78.9abc1,2,3{"key":1}

# Clean up.
statement ok
drop function to_string;


# Show which JavaScript type (`typeof`) each SQL argument type is seen as inside the UDF.
# The result is returned as a jsonb object keyed by the SQL type name.
statement ok
create function js_typeof(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb) returns jsonb language javascript as $$
return {
boolean: typeof a,
smallint: typeof b,
int: typeof c,
bigint: typeof d,
real: typeof e,
float: typeof f,
decimal: typeof g,
varchar: typeof h,
bytea: typeof i,
jsonb: typeof j,
};
$$;

# Expected mapping: all integer and float types (including bigint) report "number",
# decimal reports "bigdecimal", bytea and jsonb report "object".
query T
select js_typeof(false, 1::smallint, 2, 3, 4.5, 6.7, 8.9, 'abc', '\x010203', '{"key": 1}');
----
{"bigint": "number", "boolean": "boolean", "bytea": "object", "decimal": "bigdecimal", "float": "number", "int": "number", "jsonb": "object", "real": "number", "smallint": "number", "varchar": "string"}

# Clean up.
statement ok
drop function js_typeof;


# Round-trip test: every argument (including a nested struct) is returned
# unchanged as a field of a struct result.
statement ok
create function return_all(a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>)
returns struct<a boolean, b smallint, c int, d bigint, e real, f float, g decimal, h varchar, i bytea, j jsonb, s struct<f1 int, f2 int>>
language javascript as $$
return {a,b,c,d,e,f,g,h,i,j,s};
$$;

# `(...).*` expands the returned struct into one output column per field.
# The expected row shows the large decimal returned with all of its digits
# and the bytea value 'bytes' printed in hex form.
query T
select (return_all(
true,
1 ::smallint,
1,
1,
1,
1,
12345678901234567890.12345678,
'string',
'bytes',
'{"key":1}',
row(1, 2)::struct<f1 int, f2 int>
)).*;
----
t 1 1 1 1 1 12345678901234567890.12345678 string \x6279746573 {"key": 1} (1,2)

# Clean up.
statement ok
drop function return_all;


# Table function (`returns table`): the body emits rows with `yield`,
# one per loop iteration.
statement ok
create function series(n int) returns table (x int) language javascript as $$
for(let i = 0; i < n; i++) {
yield i;
}
$$;

# series(5) produces the five rows 0..4.
query I
select series(5);
----
0
1
2
3
4

# Clean up.
statement ok
drop function series;


# Table function with a multi-column result: each yielded object supplies the
# `word` and `length` columns of one output row.
statement ok
create function split(s varchar) returns table (word varchar, length int) language javascript as $$
for(let word of s.split(' ')) {
yield { word: word, length: word.length };
}
$$;

# Type string is `TI`: the first column (word) is text, the second (length) is an integer.
query TI
select * from split('rising wave');
----
rising 6
wave 4

# Clean up.
statement ok
drop function split;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-udf = "0.1"
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
5 changes: 3 additions & 2 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ message Function {
uint32 database_id = 3;
string name = 4;
uint32 owner = 9;
repeated string arg_names = 15;
repeated data.DataType arg_types = 5;
data.DataType return_type = 6;
string language = 7;
string link = 8;
string identifier = 10;
optional string link = 8;
optional string identifier = 10;
optional string body = 14;

oneof kind {
Expand Down
14 changes: 10 additions & 4 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,21 +471,27 @@ message WindowFunction {
message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
optional string link = 5;
// A unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
// For JavaScript UDF, it's the name of the function.
optional string identifier = 6;
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
string link = 5;
string identifier = 6;
optional string link = 5;
optional string identifier = 6;
optional string body = 7;
}
1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-js = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
Expand Down
33 changes: 28 additions & 5 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::time::Duration;

use anyhow::Context;
use arrow_schema::{Field, Fields, Schema};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
Expand Down Expand Up @@ -61,6 +62,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
enum UdfImpl {
External(Arc<ArrowFlightUdfClient>),
Wasm(Arc<WasmRuntime>),
JavaScript(JsRuntime),
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -123,6 +125,7 @@ impl UserDefinedFunction {

let output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if disable_retry_count != 0 {
Expand Down Expand Up @@ -189,16 +192,36 @@ impl Build for UserDefinedFunction {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

let identifier = udf.get_identifier()?;
let imp = match udf.language.as_str() {
"wasm" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(get_or_create_wasm_runtime(&udf.link))
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
}
_ => UdfImpl::External(get_or_create_flight_client(&udf.link)?),
"javascript" => {
let mut rt = JsRuntime::new()?;
let body = format!(
"export function {}({}) {{ {} }}",
identifier,
udf.arg_names.join(","),
udf.get_body()?
);
rt.add_function(
identifier,
arrow_schema::DataType::try_from(&return_type)?,
CallMode::CalledOnNullInput,
&body,
)?;
UdfImpl::JavaScript(rt)
}
_ => {
let link = udf.get_link()?;
UdfImpl::External(get_or_create_flight_client(link)?)
}
};

let arg_schema = Arc::new(Schema::new(
Expand All @@ -222,8 +245,8 @@ impl Build for UserDefinedFunction {
return_type,
arg_schema,
imp,
identifier: udf.identifier.clone(),
span: format!("udf_call({})", udf.identifier).into(),
identifier: identifier.clone(),
span: format!("udf_call({})", identifier).into(),
disable_retry_count: AtomicU8::new(0),
})
}
Expand Down
Loading

0 comments on commit 705be19

Please sign in to comment.