Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(udf): add initial support for JavaScript UDF #14513

Merged
merged 15 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ arrow-flight = "49"
arrow-select = "49"
arrow-ord = "49"
arrow-row = "49"
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-udf-js = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "51d941d" }
arrow-udf-wasm = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "51d941d" }
arrow-array-deltalake = { package = "arrow-array", version = "48.0.1" }
arrow-buffer-deltalake = { package = "arrow-buffer", version = "48.0.1" }
arrow-cast-deltalake = { package = "arrow-cast", version = "48.0.1" }
Expand Down
52 changes: 52 additions & 0 deletions e2e_test/udf/js_udf.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
statement ok
create function int_42() returns int language javascript as $$
/**
 * Zero-argument UDF that always yields the same integer.
 *
 * @returns {number} the constant 42
 */
export function int_42() {
    const answer = 42;
    return answer;
}
$$;

statement ok
create function gcd(int, int) returns int language javascript as $$
/**
 * Greatest common divisor of two integers (Euclidean algorithm).
 *
 * @param {?number} a - first operand; may be null
 * @param {?number} b - second operand; may be null
 * @returns {?number} the non-negative GCD, or null when either operand is
 *     null/undefined
 */
export function gcd(a, b) {
    // Loose equality on purpose: matches both null and undefined.
    if (a == null || b == null) {
        return null;
    }
    while (b != 0) {
        // Right-hand side is evaluated with the old values of a and b.
        [a, b] = [b, a % b];
    }
    // JavaScript's `%` takes the sign of the dividend, so the loop can finish
    // on a negative value (e.g. gcd(25, -15) would be -5). Normalize the sign.
    return Math.abs(a);
}
$$;

statement ok
create function to_string(boolean, smallint, int, bigint, real, float, varchar) returns varchar language javascript as $$
/**
 * Concatenates the string form of every argument, in declaration order.
 *
 * @param {...*} a..g - the seven values to stringify; any of them may be null
 * @returns {?string} the concatenated text, or null when at least one
 *     argument is null/undefined
 */
export function to_string(a, b, c, d, e, f, g) {
    const values = [a, b, c, d, e, f, g];
    // Calling .toString() on null/undefined throws a TypeError, so a missing
    // argument is answered with null instead of an exception.
    if (values.some((v) => v == null)) {
        return null;
    }
    return values.map((v) => v.toString()).join("");
}
$$;

# Zero-argument UDF: expects the constant 42.
query I
select int_42();
----
42

# gcd with two non-null integer arguments.
query I
select gcd(25, 15);
----
5

# Every argument is stringified and the pieces are concatenated in order.
query T
select to_string(false, 1::smallint, 2, 3, 4.5, 6.7, 'abc');
----
false1234.56.7abc

# Cleanup: drop the three functions created at the top of this test.
statement ok
drop function int_42;

statement ok
drop function gcd;

statement ok
drop function to_string;
2 changes: 1 addition & 1 deletion e2e_test/udf/wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "f9a9e0d" }
arrow-udf = { git = "https://github.com/risingwavelabs/arrow-udf.git", rev = "51d941d" }
genawaiter = "0.99"
rust_decimal = "1"
serde_json = "1"
4 changes: 2 additions & 2 deletions proto/catalog.proto
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@
repeated data.DataType arg_types = 5;
data.DataType return_type = 6;
string language = 7;
string link = 8;
string identifier = 10;
optional string link = 8;

Check failure on line 219 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "8" on message "Function" moved from outside to inside a oneof.
optional string identifier = 10;

Check failure on line 220 in proto/catalog.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "10" on message "Function" moved from outside to inside a oneof.
optional string body = 14;

oneof kind {
Expand Down
12 changes: 8 additions & 4 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -472,17 +472,21 @@
string language = 4;
// For external UDF: the link to the external function service.
// For WASM UDF: the link to the wasm binary file.
string link = 5;
optional string link = 5;

Check failure on line 475 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "5" on message "UserDefinedFunction" moved from outside to inside a oneof.
// A unique identifier for the function.
// For external UDF, it's the name of the function in the external function service.
// For WASM UDF, it's the name of the function in the wasm binary file.
string identifier = 6;
// For JavaScript UDF, it's the name of the function.
optional string identifier = 6;

Check failure on line 480 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "6" on message "UserDefinedFunction" moved from outside to inside a oneof.
// For JavaScript UDF, it's the body of the function.
optional string body = 7;
}

// Additional information for user defined table functions.
message UserDefinedTableFunction {
repeated data.DataType arg_types = 3;
string language = 4;
string link = 5;
string identifier = 6;
optional string link = 5;

Check failure on line 489 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "5" on message "UserDefinedTableFunction" moved from outside to inside a oneof.
optional string identifier = 6;

Check failure on line 490 in proto/expr.proto

View workflow job for this annotation

GitHub Actions / Check breaking changes in Protobuf files

Field "6" on message "UserDefinedTableFunction" moved from outside to inside a oneof.
optional string body = 7;
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
}
1 change: 1 addition & 0 deletions src/expr/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ normal = ["workspace-hack", "ctor"]
anyhow = "1"
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
arrow-udf-js = { workspace = true }
arrow-udf-wasm = { workspace = true }
async-trait = "0.1"
auto_impl = "1"
Expand Down
27 changes: 22 additions & 5 deletions src/expr/core/src/expr/expr_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::time::Duration;

use anyhow::Context;
use arrow_schema::{Field, Fields, Schema};
use arrow_udf_js::{CallMode, Runtime as JsRuntime};
use arrow_udf_wasm::Runtime as WasmRuntime;
use await_tree::InstrumentAwait;
use cfg_or_panic::cfg_or_panic;
Expand Down Expand Up @@ -61,6 +62,7 @@ const INITIAL_RETRY_COUNT: u8 = 16;
enum UdfImpl {
External(Arc<ArrowFlightUdfClient>),
Wasm(Arc<WasmRuntime>),
JavaScript(JsRuntime),
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -123,6 +125,7 @@ impl UserDefinedFunction {

let output: arrow_array::RecordBatch = match &self.imp {
UdfImpl::Wasm(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::JavaScript(runtime) => runtime.call(&self.identifier, &input)?,
UdfImpl::External(client) => {
let disable_retry_count = self.disable_retry_count.load(Ordering::Relaxed);
let result = if disable_retry_count != 0 {
Expand Down Expand Up @@ -189,16 +192,30 @@ impl Build for UserDefinedFunction {
let return_type = DataType::from(prost.get_return_type().unwrap());
let udf = prost.get_rex_node().unwrap().as_udf().unwrap();

let identifier = udf.get_identifier()?;
let imp = match udf.language.as_str() {
"wasm" => {
let link = udf.get_link()?;
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current()
.block_on(get_or_create_wasm_runtime(&udf.link))
tokio::runtime::Handle::current().block_on(get_or_create_wasm_runtime(link))
})?)
}
_ => UdfImpl::External(get_or_create_flight_client(&udf.link)?),
"javascript" => {
let mut rt = JsRuntime::new()?;
rt.add_function(
identifier,
arrow_schema::DataType::try_from(&return_type)?,
CallMode::CalledOnNullInput,
udf.get_body()?,
)?;
UdfImpl::JavaScript(rt)
}
_ => {
let link = udf.get_link()?;
UdfImpl::External(get_or_create_flight_client(link)?)
}
wangrunji0408 marked this conversation as resolved.
Show resolved Hide resolved
};

let arg_schema = Arc::new(Schema::new(
Expand All @@ -222,8 +239,8 @@ impl Build for UserDefinedFunction {
return_type,
arg_schema,
imp,
identifier: udf.identifier.clone(),
span: format!("udf_call({})", udf.identifier).into(),
identifier: identifier.clone(),
span: format!("udf_call({})", identifier).into(),
disable_retry_count: AtomicU8::new(0),
})
}
Expand Down
12 changes: 5 additions & 7 deletions src/expr/core/src/table_function/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,26 @@ pub fn new_user_defined(prost: &PbTableFunction, chunk_size: usize) -> Result<Bo
.try_collect::<_, Fields, _>()?,
));

let link = udtf.get_link()?;
let client = match udtf.language.as_str() {
"wasm" => {
// Use `block_in_place` as an escape hatch to run async code here in sync context.
// Calling `block_on` directly will panic.
UdfImpl::Wasm(tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(
crate::expr::expr_udf::get_or_create_wasm_runtime(&udtf.link),
)
tokio::runtime::Handle::current()
.block_on(crate::expr::expr_udf::get_or_create_wasm_runtime(link))
})?)
}
// connect to UDF service
_ => UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(
&udtf.link,
)?),
_ => UdfImpl::External(crate::expr::expr_udf::get_or_create_flight_client(link)?),
};

Ok(UserDefinedTableFunction {
children: prost.args.iter().map(expr_build_from_prost).try_collect()?,
return_type: prost.return_type.as_ref().expect("no return type").into(),
arg_schema,
client,
identifier: udtf.identifier.clone(),
identifier: udtf.get_identifier()?.clone(),
chunk_size,
}
.boxed())
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/catalog/function_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub struct FunctionCatalog {
pub arg_types: Vec<DataType>,
pub return_type: DataType,
pub language: String,
pub identifier: String,
pub identifier: Option<String>,
pub body: Option<String>,
pub link: String,
pub link: Option<String>,
}

#[derive(Clone, Display, PartialEq, Eq, Hash, Debug)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl SysCatalogReaderImpl {
))),
Some(ScalarImpl::Int32(function.return_type.to_oid())),
Some(ScalarImpl::Utf8(function.language.clone().into())),
Some(ScalarImpl::Utf8(function.link.clone().into())),
function.link.clone().map(|s| ScalarImpl::Utf8(s.into())),
Some(ScalarImpl::Utf8(
get_acl_items(
&Object::FunctionId(function.id.function_id()),
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ impl TableFunction {
language: c.language.clone(),
link: c.link.clone(),
identifier: c.identifier.clone(),
body: c.body.clone(),
}),
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/frontend/src/expr/user_defined_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ impl UserDefinedFunction {
arg_types,
return_type,
language: udf.get_language().clone(),
identifier: udf.get_identifier().clone(),
// TODO: Ensure if we need `body` here
body: None,
link: udf.get_link().clone(),
identifier: udf.identifier.clone(),
body: udf.body.clone(),
link: udf.link.clone(),
};

Ok(Self {
Expand Down Expand Up @@ -90,6 +89,7 @@ impl Expr for UserDefinedFunction {
language: self.catalog.language.clone(),
identifier: self.catalog.identifier.clone(),
link: self.catalog.link.clone(),
body: self.catalog.body.clone(),
})),
}
}
Expand Down
Loading
Loading