Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): add initial support for JavaScript UDF #14513

Merged
merged 15 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 137 additions & 96 deletions Cargo.lock

Large diffs are not rendered by default.

23 changes: 12 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,17 @@ prost = { version = "0.12" }
icelake = { git = "https://github.com/icelake-io/icelake", rev = "3f7b53ba5b563524212c25810345d1314678e7fc", features = [
"prometheus",
] }
arrow-array = "49"
arrow-arith = "49"
arrow-cast = "49"
arrow-schema = "49"
arrow-buffer = "49"
arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-array = "50"
arrow-arith = "50"
arrow-cast = "50"
arrow-schema = "50"
arrow-buffer = "50"
arrow-flight = "50"
arrow-select = "50"
arrow-ord = "50"
arrow-row = "50"
arrow-udf-js = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "70fae28" }
arrow-udf-wasm = "0.1"
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand All @@ -142,7 +143,7 @@ arrow-schema-deltalake = { package = "arrow-schema", version = "48.0.1" }
deltalake = { git = "https://github.com/risingwavelabs/delta-rs", rev = "5c2dccd4640490202ffe98adbd13b09cef8e007b", features = [
"s3-no-concurrent-write",
] }
parquet = "49"
parquet = "50"
thiserror-ext = "0.0.11"
tikv-jemalloc-ctl = { git = "https://github.com/risingwavelabs/jemallocator.git", rev = "64a2d9" }
tikv-jemallocator = { git = "https://github.com/risingwavelabs/jemallocator.git", features = [
Expand Down
90 changes: 90 additions & 0 deletions e2e_test/udf/js_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
statement ok
create function int_42() returns int language javascript as $$
return 42;
$$;

statement ok
create function gcd(a int, b int) returns int language javascript as $$
// required before we support `RETURNS NULL ON NULL INPUT`
if(a == null || b == null) {
return null;
}
while (b != 0) {
let t = b;
b = a % b;
a = t;
}
return a;
$$;

statement ok
create function to_string(a boolean, b smallint, c int, d bigint, e real, f float, g varchar, h bytea, i jsonb) returns varchar language javascript as $$
return a.toString() + b.toString() + c.toString() + d.toString() + e.toString() + f.toString() + g.toString() + h.toString() + JSON.stringify(i);
$$;

# show data types in javascript
statement ok
create function js_typeof(a boolean, b smallint, c int, d bigint, e real, f float, g varchar, h bytea, i jsonb) returns jsonb language javascript as $$
return {
boolean: typeof a,
smallint: typeof b,
int: typeof c,
bigint: typeof d,
real: typeof e,
float: typeof f,
varchar: typeof g,
bytea: typeof h,
jsonb: typeof i,
};
$$;

statement ok
create function series(n int) returns table (x int) language javascript as $$
for(let i = 0; i < n; i++) {
yield i;
}
$$;

query I
select int_42();
----
42

query I
select gcd(25, 15);
----
5

query T
select to_string(false, 1::smallint, 2, 3, 4.5, 6.7, 'abc', '\x010203', '{"key": 1}');
----
false1234.56.7abc1,2,3{"key":1}

query T
select js_typeof(false, 1::smallint, 2, 3, 4.5, 6.7, 'abc', '\x010203', '{"key": 1}');
----
{"bigint": "number", "boolean": "boolean", "bytea": "object", "float": "number", "int": "number", "jsonb": "object", "real": "number", "smallint": "number", "varchar": "string"}

query I
select series(5);
----
0
1
2
3
4

statement ok
drop function int_42;

statement ok
drop function gcd;

statement ok
drop function to_string;

statement ok
drop function series;

statement ok
drop function js_typeof;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-udf = "0.1"
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
5 changes: 3 additions & 2 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@
uint32 database_id = 3;
string name = 4;
uint32 owner = 9;
repeated string arg_names = 15;
repeated data.DataType arg_types = 5;
data.DataType return_type = 6;
string language = 7;
string link = 8;
string identifier = 10;
optional string link = 8;

Check failure on line 220 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "8" on message "Function" moved from outside to inside a oneof.
optional string identifier = 10;

Check failure on line 221 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "10" on message "Function" moved from outside to inside a oneof.
optional string body = 14;

oneof kind {
Expand Down
14 changes: 10 additions & 4 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -468,21 +468,27 @@
message UserDefinedFunction {
repeated ExprNode children = 1;
string name = 2;
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
optional string link = 5;

Check failure on line 476 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "5" on message "UserDefinedFunction" moved from outside to inside a oneof.
// An unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
// For JavaScript UDF, it's the name of the function.
optional string identifier = 6;

Check failure on line 481 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "6" on message "UserDefinedFunction" moved from outside to inside a oneof.
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated string arg_names = 8;
repeated data.DataType arg_types = 3;
string language = 4;
string link = 5;
string identifier = 6;
optional string link = 5;

Check failure on line 491 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "5" on message "UserDefinedTableFunction" moved from outside to inside a oneof.
optional string identifier = 6;

Check failure on line 492 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "6" on message "UserDefinedTableFunction" moved from outside to inside a oneof.
optional string body = 7;
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
}
1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-js = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
Expand Down
33 changes: 28 additions & 5 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::time::Duration;

use anyhow::Context;
use arrow_schema::{Field, Fields, Schema};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
Expand Down Expand Up @@ -61,6 +62,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
enum UdfImpl {
External(Arc<ArrowFlightUdfClient>),
Wasm(Arc<WasmRuntime>),
JavaScript(JsRuntime),
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -123,6 +125,7 @@ impl UserDefinedFunction {

let output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if disable_retry_count != 0 {
Expand Down Expand Up @@ -189,16 +192,36 @@ impl Build for UserDefinedFunction {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

let identifier = udf.get_identifier()?;
let imp = match udf.language.as_str() {
"wasm" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(get_or_create_wasm_runtime(&udf.link))
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
}
_ => UdfImpl::External(get_or_create_flight_client(&udf.link)?),
"javascript" => {
let mut rt = JsRuntime::new()?;
let body = format!(
"export function {}({}) {{ {} }}",
identifier,
udf.arg_names.join(","),
udf.get_body()?
);
rt.add_function(
identifier,
arrow_schema::DataType::try_from(&return_type)?,
CallMode::CalledOnNullInput,
&body,
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
)?;
UdfImpl::JavaScript(rt)
}
_ => {
let link = udf.get_link()?;
UdfImpl::External(get_or_create_flight_client(link)?)
}
};

let arg_schema = Arc::new(Schema::new(
Expand All @@ -222,8 +245,8 @@ impl Build for UserDefinedFunction {
return_type,
arg_schema,
imp,
identifier: udf.identifier.clone(),
span: format!("udf_call({})", udf.identifier).into(),
identifier: identifier.clone(),
span: format!("udf_call({})", identifier).into(),
disable_retry_count: AtomicU8::new(0),
})
}
Expand Down
43 changes: 35 additions & 8 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::sync::Arc;

use arrow_array::RecordBatch;
use arrow_schema::{Field, Fields, Schema, SchemaRef};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use cfg_or_panic::cfg_or_panic;
use futures_util::stream;
Expand All @@ -42,6 +43,7 @@ pub struct UserDefinedTableFunction {
enum UdfImpl {
External(Arc<ArrowFlightUdfClient>),
Wasm(Arc<WasmRuntime>),
JavaScript(JsRuntime),
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -70,6 +72,11 @@ impl UdfImpl {
yield res?;
}
}
UdfImpl::JavaScript(runtime) => {
for res in runtime.call_table_function(identifier, &input, 1024)? {
yield res?;
}
}
UdfImpl::Wasm(runtime) => {
for res in runtime.call_table_function(identifier, &input)? {
yield res?;
Expand Down Expand Up @@ -177,28 +184,48 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
.try_collect::<_, Fields, _>()?,
));

let identifier = udtf.get_identifier()?;
let return_type = DataType::from(prost.get_return_type()?);

let client = match udtf.language.as_str() {
"wasm" => {
let link = udtf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(
crate::expr::expr_udf::get_or_create_wasm_runtime(&udtf.link),
)
tokio::runtime::Handle::current()
.block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link))
})?)
}
"javascript" => {
let mut rt = JsRuntime::new()?;
let body = format!(
"export function* {}({}) {{ {} }}",
identifier,
udtf.arg_names.join(","),
udtf.get_body()?
);
rt.add_function(
identifier,
arrow_schema::DataType::try_from(&return_type)?,
CallMode::CalledOnNullInput,
&body,
)?;
UdfImpl::JavaScript(rt)
}
// connect to UDF service
_ => UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(
&udtf.link,
)?),
_ => {
let link = udtf.get_link()?;
UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?)
}
};

Ok(UserDefinedTableFunction {
children: prost.args.iter().map(expr_build_from_prost).try_collect()?,
return_type: prost.return_type.as_ref().expect("no return type").into(),
return_type,
arg_schema,
client,
identifier: udtf.identifier.clone(),
identifier: identifier.clone(),
chunk_size,
}
.boxed())
Expand Down
6 changes: 4 additions & 2 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ pub struct FunctionCatalog {
pub name: String,
pub owner: u32,
pub kind: FunctionKind,
pub arg_names: Vec<String>,
pub arg_types: Vec<DataType>,
pub return_type: DataType,
pub language: String,
pub identifier: String,
pub identifier: Option<String>,
pub body: Option<String>,
pub link: String,
pub link: Option<String>,
}

#[derive(Clone, Display, PartialEq, Eq, Hash, Debug)]
Expand Down Expand Up @@ -60,6 +61,7 @@ impl From<&PbFunction> for FunctionCatalog {
name: prost.name.clone(),
owner: prost.owner,
kind: prost.kind.as_ref().unwrap().into(),
arg_names: prost.arg_names.clone(),
arg_types: prost.arg_types.iter().map(|arg| arg.into()).collect(),
return_type: prost.return_type.as_ref().expect("no return type").into(),
language: prost.language.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl SysCatalogReaderImpl {
))),
Some(ScalarImpl::Int32(function.return_type.to_oid())),
Some(ScalarImpl::Utf8(function.language.clone().into())),
Some(ScalarImpl::Utf8(function.link.clone().into())),
function.link.clone().map(|s| ScalarImpl::Utf8(s.into())),
Some(ScalarImpl::Utf8(
get_acl_items(
&Object::FunctionId(function.id.function_id()),
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ impl TableFunction {
.udtf_catalog
.as_ref()
.map(|c| UserDefinedTableFunctionPb {
arg_names: c.arg_names.clone(),
arg_types: c.arg_types.iter().map(|t| t.to_protobuf()).collect(),
language: c.language.clone(),
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
}),
}
}
Expand Down
Loading
Loading